NERN - LEARNING NEURAL REPRESENTATIONS FOR NEURAL NETWORKS

Abstract

Neural representations have recently been shown to effectively reconstruct a wide range of signals, from 3D meshes and shapes to images and videos. We show that, when adapted correctly, neural representations can be used to directly represent the weights of a pre-trained convolutional neural network, resulting in a Neural Representation for Neural Networks (NeRN). Inspired by the coordinate inputs of previous neural representation methods, we assign a coordinate to each convolutional kernel in our network based on its position in the architecture, and optimize a predictor network to map coordinates to their corresponding weights. Similarly to the spatial smoothness of visual scenes, we show that incorporating a smoothness constraint over the original network's weights helps NeRN achieve a better reconstruction. In addition, since slight perturbations in pre-trained model weights can result in a considerable accuracy loss, we employ techniques from the field of knowledge distillation to stabilize the learning process. We demonstrate the effectiveness of NeRN in reconstructing widely used architectures on CIFAR-10, CIFAR-100, and ImageNet. Finally, we present two applications using NeRN, demonstrating the capabilities of the learned representations.

1. INTRODUCTION

In the last decade, neural networks have proven to be very effective at learning representations over a wide variety of domains. Recently, NeRF (Mildenhall et al., 2020) demonstrated that a relatively simple neural network can directly learn to represent a 3D scene. This follows the general recipe for neural representations, where the task is modeled as a prediction problem from some coordinate system to an output that represents the scene. Once trained, the scene is encoded in the weights of the neural network, and novel views can thus be rendered for previously unobserved coordinates. NeRF outperformed previous view synthesis methods but, more importantly, offered a new perspective on scene representation. Following the success of NeRF, there have been various attempts to learn neural representations in other domains as well. SIREN (Sitzmann et al., 2020) shows that neural representations can successfully model images when adapted to handle high frequencies. NeRV (Chen et al., 2021) utilizes neural representations for video encoding, where the video is represented as a mapping from a timestamp to the pixel values of that specific frame.

In this paper, we explore the idea of learning neural representations for pre-trained neural networks. We consider representing a Convolutional Neural Network (CNN) using a separate predictor neural network, resulting in a neural representation for neural networks, or NeRN. We model this task as a problem of mapping each weight's coordinates to its corresponding values in the original network. Specifically, our coordinate system is defined as a (Layer, Filter, Channel) tuple, denoted by (l, f, c), where each coordinate corresponds to the weights of a k × k convolutional kernel. NeRN is trained to map each input coordinate back to the original kernel weights. One can then reconstruct the original network by querying NeRN over all possible coordinates.
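This coordinate scheme can be sketched with a small hypothetical helper that enumerates one (l, f, c) tuple per k × k kernel of a CNN; the function name and the PyTorch-based traversal below are illustrative assumptions, not part of the paper's released framework.

```python
import torch.nn as nn

def kernel_coordinates(model: nn.Module):
    """Enumerate a (layer, filter, channel) coordinate for every
    k x k kernel in the model's convolutional layers."""
    convs = [m for m in model.modules() if isinstance(m, nn.Conv2d)]
    coords = []
    for l, conv in enumerate(convs):
        # Conv2d weights are laid out as (out_channels, in_channels, k, k);
        # each (f, c) pair indexes one k x k kernel.
        out_channels, in_channels = conv.weight.shape[:2]
        for f in range(out_channels):
            for c in range(in_channels):
                coords.append((l, f, c))
    return coords
```

Querying the predictor over every coordinate returned by such a helper, and scattering the predicted k × k kernels back into place, would reconstruct the original network's weights.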
While a larger predictor network can trivially learn to overfit a smaller original network, we show that successfully creating a compact implicit representation is not trivial. To achieve this, we propose methods for introducing smoothness over the learned signal, i.e. the original network's weights, either by applying a regularization term during the original network's training or by applying post-training permutations over the original network's weights. In addition, we design a training scheme inspired by knowledge distillation methods that allows for a better and more stable optimization process. Similarly to other neural representations, a trained NeRN represents the weights of the specific neural network it was trained on, which to the best of our knowledge differs from previous weight prediction papers such as Ha et al. (2016); Schürholt et al. (2021); Knyazev et al. (2021); Schürholt et al. (2022). We demonstrate NeRN's reconstruction results on several classification benchmarks.

Successfully learning a NeRN provides some additional interesting insights. For example, a NeRN with limited capacity must prioritize the original weights during training. The reconstruction error can then be used to study the importance of different weights. In addition to our proposed method and extensive experiments, we provide a scalable framework for NeRN built using PyTorch (Paszke et al., 2019) that can be extended to support new models and tasks. We hope that our proposed NeRN will give a new perspective on neural networks for future research.
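As a rough illustration of a smoothness term over the original network's weights, the sketch below penalizes the L2 distance between kernels that neighbor each other along the input-channel axis. This is one plausible form of such a regularizer under stated assumptions; the paper's exact formulation may differ.

```python
import torch
import torch.nn as nn

def kernel_smoothness_penalty(conv: nn.Conv2d) -> torch.Tensor:
    """Mean squared L2 distance between k x k kernels that are
    adjacent along the input-channel axis of one conv layer.
    Assumes the layer has at least two input channels."""
    w = conv.weight                # (out_ch, in_ch, k, k)
    diff = w[:, 1:] - w[:, :-1]    # neighboring kernels, per filter
    return diff.pow(2).mean()
```

Such a term would be added, with a small weight, to the task loss when training the original network, encouraging adjacent kernels to be similar and thus easier for a compact NeRN to fit.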

2. RELATED WORK

Neural representations have recently proven to be a powerful tool for representing various signals using coordinate inputs fed into an MLP (multilayer perceptron). The superiority of implicit 3D shape neural representations (Sitzmann et al., 2019; Jiang et al., 2020; Peng et al., 2020; Chabra et al., 2020; Mildenhall et al., 2020) over previous representations such as grids or meshes has been demonstrated in Park et al. (2019); Chen & Zhang (2019); Genova et al. (2020). Following NeRF's success, additional applications arose for neural representations, such as image compression (Dupont et al., 2021), video encoding (Chen et al., 2021), camera pose estimation (Yen-Chen et al., 2021), and more. Some of these redesigned the predictor network to complement the learned signal. For example, Chen et al. (2021) adopted a CNN for frame prediction. In our work, we adopt a simple MLP while incorporating additional methods to fit the characteristics of convolutional weights.

Weight prediction refers to generating a neural network's weights using an additional predictor network. In Ha et al. (2016), the weights of a larger network are predicted using a smaller internal network, denoted a HyperNetwork. The HyperNetwork is trained to directly solve the task, while also learning the input vectors for parameter prediction. Deutsch (2018) followed this idea by exploring the trade-off between accuracy and diversity in parameter prediction. In contrast, we aim to directly represent a pre-trained neural network, using fixed inputs. Several works have explored the idea of using a model dataset for weight prediction. For instance, Schürholt et al. (2021) propose a representation learning approach for predicting hyperparameters and downstream performance, and Schürholt et al. (2022) explored a similar idea for weight initialization while promoting diversity. Zhang et al. (2018a); Knyazev et al. (2021) leverage a GNN (graph neural network) to predict the parameters of a previously unseen architecture by modeling it as a graph input.

Knowledge distillation is mostly used for improving the performance of a compressed network, given a pre-trained larger teacher network. There are two main types of knowledge used in student-teacher learning. First, response-based methods (Ba & Caruana, 2014; Hinton et al., 2014; Chen et al., 2017; 2019) focus on the output classification logits. Second, feature-based methods (Romero et al., 2015; Zagoruyko & Komodakis, 2017) focus on feature maps (activations) throughout the network. The distillation scheme can generally be categorized as offline (Zagoruyko & Komodakis, 2017; Huang & Wang, 2017; Passalis & Tefas, 2018; Heo et al., 2019; Mirzadeh et al., 2020; Li et al., 2020) or online (Zhang et al., 2018b; Chen et al., 2020; Xie et al., 2019). In our work, we leverage offline response-based and feature-based knowledge distillation to guide the learning process.
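A hedged sketch of how such offline response-based and feature-based distillation could be combined with a direct weight-reconstruction term to guide training. The function name, the loss weights `alpha`/`beta`/`gamma`, and the exact choice of terms are illustrative assumptions, not the paper's reported objective.

```python
import torch
import torch.nn.functional as F

def combined_loss(pred_w, orig_w,
                  student_logits, teacher_logits,
                  student_feats, teacher_feats,
                  alpha=1.0, beta=1.0, gamma=1.0):
    """Combine (i) direct weight reconstruction, (ii) response-based
    distillation on output logits, and (iii) feature-based distillation
    on intermediate activations. Weights are illustrative."""
    # (i) reconstruct the original kernel weights
    recon = F.mse_loss(pred_w, orig_w)
    # (ii) match the teacher's output distribution (KL on logits)
    response = F.kl_div(F.log_softmax(student_logits, dim=1),
                        F.softmax(teacher_logits, dim=1),
                        reduction="batchmean")
    # (iii) match intermediate feature maps layer by layer
    feature = sum(F.mse_loss(s, t)
                  for s, t in zip(student_feats, teacher_feats))
    return alpha * recon + beta * response + gamma * feature
```

Here the "student" outputs would come from the network rebuilt from NeRN's predicted weights, and the "teacher" outputs from the original pre-trained network.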

3. METHOD

In this work we focus on representing convolutional classification networks. Our overall pipeline is presented in Figure 1, with extended details below on the design choices and training of NeRN.

3.1. DESIGNING NERNS

Similar to other neural representations, at its core NeRN is composed of a simple neural network whose input is a positional embedding representing a weight coordinate in the original network, and whose output is the corresponding k × k kernel weights.
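A minimal sketch of such a predictor, assuming a sinusoidal positional embedding of the (l, f, c) coordinate fed into a small MLP that outputs the k × k kernel weights. The class name, layer sizes, number of frequencies, and the specific embedding are all illustrative assumptions, not the paper's exact design.

```python
import torch
import torch.nn as nn

class CoordinatePredictor(nn.Module):
    """Map a (layer, filter, channel) coordinate, encoded with a
    sinusoidal positional embedding, to k*k kernel weights."""
    def __init__(self, num_freqs=8, hidden=256, kernel_size=3):
        super().__init__()
        self.num_freqs = num_freqs
        in_dim = 3 * 2 * num_freqs  # 3 coords, sin+cos per frequency
        self.mlp = nn.Sequential(
            nn.Linear(in_dim, hidden), nn.ReLU(),
            nn.Linear(hidden, hidden), nn.ReLU(),
            nn.Linear(hidden, kernel_size * kernel_size),
        )

    def embed(self, coords):
        # coords: (batch, 3) integer (l, f, c) indices
        freqs = 2.0 ** torch.arange(self.num_freqs, dtype=torch.float32)
        angles = coords.float().unsqueeze(-1) * freqs  # (batch, 3, num_freqs)
        emb = torch.cat([angles.sin(), angles.cos()], dim=-1)
        return emb.flatten(1)                          # (batch, 3 * 2 * num_freqs)

    def forward(self, coords):
        return self.mlp(self.embed(coords))
```

Batching coordinates through such a predictor and reshaping each output row to k × k would yield the reconstructed kernels.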






