HYPERPARAMETER OPTIMIZATION THROUGH NEURAL NETWORK PARTITIONING

Abstract

Well-tuned hyperparameters are crucial for obtaining good generalization behavior in neural networks. They can enforce appropriate inductive biases, regularize the model, and improve performance, especially in the presence of limited data. In this work, we propose a simple and efficient way to optimize hyperparameters, inspired by the marginal likelihood, an optimization objective that requires no validation data. Our method partitions the training data and a neural network model into K data shards and K parameter partitions, respectively. Each partition is associated with, and optimized only on, specific data shards. Combining these partitions into subnetworks allows us to define the "out-of-training-sample" loss of a subnetwork, i.e., the loss on data shards unseen by that subnetwork, as the objective for hyperparameter optimization. We demonstrate that this objective can be used to optimize a variety of different hyperparameters in a single training run while being significantly cheaper computationally than alternative methods that aim to optimize the marginal likelihood for neural networks. Lastly, we also focus on optimizing hyperparameters in federated learning, where retraining and cross-validation are particularly challenging.

1. INTRODUCTION

Due to their remarkable generalization capabilities, deep neural networks have become the de facto models for a wide range of complex tasks. Combining large models, large-enough datasets, and sufficient computing capabilities enables researchers to train powerful models through gradient descent. Regardless of the data regime, however, the choice of hyperparameters, such as the neural architecture, data augmentation strategies, regularization, or the choice of optimizer, plays a crucial role in the final model's generalization capabilities. Hyperparameters allow encoding good inductive biases that effectively constrain the model's hypothesis space (e.g., convolutions for vision tasks), speed up learning, or prevent overfitting in the case of limited data.

Whereas gradient descent enables the tuning of model parameters, accessing hyperparameter gradients is more complicated. The traditional and general way to optimize hyperparameters operates as follows: 1) partition the dataset into training and validation data, 2) pick a set of hyperparameters and optimize the model on the training data, 3) measure the performance of the model on the validation data, and finally 4) use the validation metric to score models or to perform search over the space of hyperparameters. This approach inherently requires training multiple models and consequently spending resources on models that will be discarded. Furthermore, traditional tuning requires a validation set, since optimizing the hyperparameters on the training set alone cannot identify the right inductive biases. A canonical example is data augmentations: they are not expected to improve training-set performance, but they greatly help with generalization.

In the low-data regime, setting aside a validation set that cannot be used for tuning model parameters is undesirable. Picking the right amount of validation data is a hyperparameter in itself. The conventional rule of thumb of using ~10% of all data can result in significant overfitting, as pointed out by Lorraine et al. (2019), when one has a sufficiently large number of hyperparameters to tune. Furthermore, a validation set can be challenging to obtain in many use cases. An example is Federated Learning (FL) (McMahan et al., 2017), which we specifically consider in our experimental section. In FL, each extra training run (e.g., for a specific hyperparameter setting) comes with additional, non-trivial costs.

Different approaches have been proposed to address these challenges. Some schemes optimize hyperparameters during a single training run by making the hyperparameters part of the model (e.g., learning dropout rates with concrete dropout (Gal et al., 2017), learning architectures with DARTS (Liu et al., 2018), and learning data augmentations with schemes as in Benton et al. (2020); van der Wilk et al. (2018)). In cases where the model does not depend on the hyperparameters directly, but only indirectly through their effect on the value of the final parameters (through optimization), schemes for differentiating through the training procedure have been proposed, such as that of Lorraine et al. (2019). Another way of optimizing hyperparameters without a validation set is through the canonical view on model selection (and hence hyperparameter optimization) through the Bayesian lens: the concept of optimizing the marginal likelihood. For deep neural networks, however, the marginal likelihood is difficult to compute.
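For concreteness, a minimal, self-contained sketch of this four-step loop, using ridge regression on synthetic data as a stand-in for a neural network and the regularization strength as the hyperparameter ψ (all names and values here are illustrative, not part of any particular method):

    import numpy as np

    rng = np.random.default_rng(0)
    X = rng.normal(size=(100, 5))
    y = X @ rng.normal(size=5) + 0.1 * rng.normal(size=100)

    # 1) partition the dataset into training and validation data (~10% held out)
    X_tr, y_tr = X[:90], y[:90]
    X_val, y_val = X[90:], y[90:]

    def fit_ridge(X, y, lam):
        # 2) optimize the model on the training data for a given hyperparameter
        d = X.shape[1]
        return np.linalg.solve(X.T @ X + lam * np.eye(d), X.T @ y)

    scores = {}
    for lam in [0.01, 0.1, 1.0, 10.0]:        # candidate hyperparameters psi
        w = fit_ridge(X_tr, y_tr, lam)
        # 3) measure performance on the validation data
        scores[lam] = np.mean((X_val @ w - y_val) ** 2)

    # 4) use the validation metric to select among hyperparameter settings
    best_lam = min(scores, key=scores.get)
    print("best lambda:", best_lam, "val MSE:", scores[best_lam])

Note that every candidate except the winner is trained and then discarded, which is exactly the cost this paper seeks to avoid.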
Prior works have therefore developed various approximations to the marginal likelihood for deep learning models and used them to optimize hyperparameters, such as those of data augmentation (Schwöbel et al., 2021; Immer et al., 2022). However, these still come at a significant added computational expense and do not scale to larger deep learning problems. This paper presents a novel approach to hyperparameter optimization, inspired by the marginal likelihood, that requires only a single training run and no validation set. Our method is more scalable than previous works that rely on the marginal likelihood and Laplace approximations (which require computing or inverting a Hessian (Immer et al., 2021)) and is broadly applicable to any hierarchical modelling setup.
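As a rough illustration of the bookkeeping behind the partitioned objective from the abstract (the round-robin assignment of parameters to partitions and the rule "the subnetwork for shard k excludes partition k" are simplifying assumptions for illustration, not the paper's precise construction):

    import numpy as np

    K = 4                                    # number of data shards / partitions
    n_params = 12
    # Assign each parameter to one of K partitions (round-robin, for illustration).
    partition_of = np.arange(n_params) % K

    def subnetwork_mask(held_out_shard):
        # Subnetwork formed from the partitions *not* associated with the
        # held-out shard; that shard is "out-of-training-sample" for it.
        return (partition_of != held_out_shard).astype(np.float32)

    for k in range(K):
        mask = subnetwork_mask(k)
        print(f"shard {k}: subnetwork uses {int(mask.sum())}/{n_params} parameters; "
              f"its loss on shard {k} is out-of-training-sample")

The hyperparameters ψ would then be tuned by gradient descent on the out-of-training-sample losses accumulated across all K subnetworks within a single training run.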

2. MARGINAL LIKELIHOOD AND PRIOR WORK

In Bayesian inference, the rules of probability dictate how any unknown, such as parameters w or hyperparameters ψ, should be determined given observed data D. Let p(w) be a prior over w and p(D|w, ψ) be a likelihood for D, with ψ being the hyperparameters. We are then interested in the posterior given the data, p(w|D, ψ) = p(D|w, ψ)p(w)/p(D|ψ). The denominator p(D|ψ) is known as the marginal likelihood, as it measures the probability of observing the data given ψ, irrespective of the value of w:

p(D|ψ) = ∫ p(w) p(D|w, ψ) dw.

The marginal likelihood has many desirable properties that make it a good criterion for model selection and hyperparameter optimization. It intuitively implements the essence of the Occam's razor principle (MacKay, 2003, § 28). In the PAC-Bayesian literature, it has been shown that a higher marginal likelihood gives tighter frequentist upper bounds on the generalization performance of a given model class (McAllester, 1998; Germain et al., 2016). It also has close links to cross-validation (see Section 2.1) and can be computed from the training data alone. However, computing the marginal likelihood in deep learning models is usually prohibitively expensive, and many recent works have proposed schemes to approximate it for differentiable model selection (Lyle et al., 2020; Immer et al., 2021; 2022; Schwöbel et al., 2021).
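To make the integral concrete, here is a minimal Monte Carlo sketch for a toy one-dimensional Gaussian model, where ψ is taken to be the observation noise scale; the model, prior, and all values are illustrative assumptions:

    import numpy as np
    from scipy.stats import norm

    rng = np.random.default_rng(0)
    data = np.array([0.8, 1.1, 0.9, 1.3])    # toy dataset D

    def log_marginal_likelihood(psi, n_samples=100_000):
        # Monte Carlo estimate of log p(D|psi) = log ∫ p(w) p(D|w, psi) dw
        w = rng.normal(0.0, 1.0, size=n_samples)        # w ~ p(w) = N(0, 1)
        # log p(D|w, psi) for every prior sample, then a log-mean-exp
        log_lik = norm.logpdf(data[None, :], loc=w[:, None], scale=psi).sum(axis=1)
        return np.logaddexp.reduce(log_lik) - np.log(n_samples)

    # A better-matched noise scale psi yields a higher marginal likelihood.
    for psi in [0.1, 0.3, 1.0]:
        print(f"psi = {psi}: log p(D|psi) ~ {log_marginal_likelihood(psi):.3f}")

For a deep network, the analogous integral over millions of weights has no such cheap estimator, which is what motivates the approximations discussed above.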

2.1. "LEARNING SPEED" PERSPECTIVE

Lyle et al. (2020) and Fong and Holmes (2020) pointed out the correspondence between "learning speed" and the marginal likelihood. Namely, the log marginal likelihood of the data D conditioned on some hyperparameters ψ can be written as:

log p(D|ψ) = Σ_k log E_{p(w|D_{1:k-1}, ψ)}[p(D_k|w, ψ)] ≥ Σ_k E_{p(w|D_{1:k-1}, ψ)}[log p(D_k|w, ψ)]    (1)

where (D_1, . . . , D_C) is an arbitrary partitioning of the training dataset D into C shards, or chunks (we use the terms "chunk" and "shard" interchangeably), and p(w|D_{1:k}, ψ) is the posterior over the parameters of a function f_w : X → Y, from the input domain X to the target domain Y, after seeing the data in shards 1 through k. The right-hand side can be interpreted as a type of cross-validation in which we fix an ordering over the shards and measure the "validation" performance on each shard D_k using a model trained on the preceding shards D_{1:k-1}.
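A small sketch of the right-hand side of Eq. (1) for a toy Gaussian location model; a maximum-likelihood point estimate of w stands in for the posterior expectation, and the first shard's prior-predictive term is dropped, both of which are illustrative simplifications:

    import numpy as np
    from scipy.stats import norm

    rng = np.random.default_rng(1)
    D = rng.normal(loc=2.0, scale=1.0, size=24)   # toy training data
    chunks = np.array_split(D, 4)                 # (D_1, ..., D_C) with C = 4

    def learning_speed(chunks, psi):
        # Sum over shards of log p(D_k | w fit on D_{1:k-1}, psi).
        total, seen = 0.0, np.empty(0)
        for D_k in chunks:
            if seen.size > 0:
                w_hat = seen.mean()               # point estimate from D_{1:k-1}
                total += norm.logpdf(D_k, loc=w_hat, scale=psi).sum()
            seen = np.concatenate([seen, D_k])
        return total

    # Hyperparameter selection without a held-out set: pick the noise scale
    # psi with the highest accumulated out-of-sample log-likelihood.
    for psi in [0.5, 1.0, 2.0]:
        print(f"psi = {psi}: objective = {learning_speed(chunks, psi):.2f}")

Models that generalize from earlier shards to later ones ("learn fast") score highly on this objective, which is the intuition the partitioning method in this paper builds on.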
