NEIGHBORHOOD-AWARE NEURAL ARCHITECTURE SEARCH

Abstract

Existing neural architecture search (NAS) methods often return an architecture that achieves good search performance but generalizes poorly to the test setting. To achieve better generalization, we propose a novel neighborhood-aware NAS formulation that identifies flat-minima architectures in the search space, under the assumption that flat minima generalize better than sharp minima. The phrase "flat-minima architecture" refers to an architecture whose performance is stable under small perturbations of the architecture (e.g., replacing a convolution with a skip connection). Our formulation accounts for the "flatness" of an architecture by aggregating its performance over the neighborhood of that architecture. We demonstrate a principled way to apply our formulation to existing search algorithms, including sampling-based and gradient-based algorithms. To facilitate the application to gradient-based algorithms, we also propose a differentiable representation for the neighborhood of architectures. Based on our formulation, we propose neighborhood-aware random search (NA-RS) and neighborhood-aware differentiable architecture search (NA-DARTS). Notably, by simply augmenting DARTS (Liu et al., 2019) with our formulation, NA-DARTS finds architectures that perform better than or on par with those found by state-of-the-art NAS methods on established benchmarks, including CIFAR-10, CIFAR-100, and ImageNet.

1. INTRODUCTION

The process of automatic neural architecture design, also called neural architecture search (NAS), is a promising technology for improving the performance and efficiency of deep learning applications (Zoph & Le, 2017; Zoph et al., 2018; Liu et al., 2019). NAS methods typically minimize the validation loss to find the optimal architecture. However, directly optimizing such an objective may cause the search algorithm to overfit to the search setting, i.e., to find a solution architecture that achieves good search performance but generalizes poorly to the test setting. This type of overfitting results from differences between the search and test settings, such as the length of training schedules (Zoph & Le, 2017; Zoph et al., 2018), cross-architecture weight sharing (Liu et al., 2019; Pham et al., 2018), and the use of proxy datasets during search (Zoph & Le, 2017; Zoph et al., 2018; Liu et al., 2019). To achieve better generalization, we propose a novel NAS formulation that searches for "flat-minima architectures", which we define as architectures that perform well under small perturbations of the architecture (Figure 1). One example of an architectural perturbation is replacing a convolutional operator with a skip connection (identity mapping). Our work takes inspiration from prior work on neural network training, which shows that flat minima of the loss function correspond to network weights that generalize better than sharp minima (Hochreiter & Schmidhuber, 1997). We show that flat minima in the architecture space also generalize better to a new data distribution than sharp minima (Sec. 3.3).
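As a concrete illustration of operation-level perturbations (a sketch, not the paper's implementation), a cell can be encoded as a tuple with one operation per edge, and its neighbors enumerated by swapping operations; the operation vocabulary `OPS` and this encoding are assumptions for the example:

```python
import itertools

# Assumed operation vocabulary and cell encoding (one operation per edge);
# actual NAS search spaces (e.g., DARTS, NAS-Bench-201) differ in detail.
OPS = ["conv_3x3", "conv_1x1", "skip_connect", "avg_pool_3x3", "none"]

def neighborhood(arch, distance=1):
    """All architectures reachable by perturbing exactly `distance` operations.

    A distance-1 neighbor differs from `arch` in a single operation, e.g., one
    convolution replaced by a skip connection; `distance` sets the scope of the
    neighborhood.
    """
    neighbors = set()
    for edges in itertools.combinations(range(len(arch)), distance):
        # enumerate every alternative assignment for the chosen edges
        alternatives = [[op for op in OPS if op != arch[e]] for e in edges]
        for replacement in itertools.product(*alternatives):
            candidate = list(arch)
            for e, op in zip(edges, replacement):
                candidate[e] = op
            neighbors.add(tuple(candidate))
    return neighbors

arch = ("conv_3x3", "skip_connect", "conv_1x1")
print(len(neighborhood(arch)))  # 3 edges x 4 alternative ops = 12 neighbors
```

The degree of perturbation grows the neighborhood combinatorially: with `distance=2`, the same 3-edge cell already has 48 neighbors.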
Unlike the standard NAS formulation that directly optimizes the performance of a single architecture, i.e., α* = arg min_{α∈A} f(α), we optimize the aggregated performance over the neighborhood of an architecture:

α* = arg min_{α∈A} g(f(N(α))),

where f(·) is a task-specific error metric, α denotes an architecture in the search space A, N(α) denotes the neighborhood of architecture α, and g(·) is an aggregation function (e.g., the mean function). Note that we overload the notation of the error metric f(·): when the input is a set of architectures, f(·) returns the set of their errors, i.e., f(N(α)) = {f(α') | α' ∈ N(α)}. Common choices for f(·) are the validation loss and the negative validation accuracy. We discuss the neighborhood N(α) and the aggregation function g(·) in more detail below.

[Figure 1: Loss landscape visualization of the found architecture. We project architectures (instead of the network weights) onto a 2D plane. The architectures are sampled along two prominent directions (the two axes, λ0 and λ1), with (0, 0) denoting the found architecture. The architecture found by our method (right) is a much flatter minimum than that found with the standard formulation (left). We provide visualization details in the appendix.]

To implement our formulation, one must define the neighborhood N(α) and specify an aggregation function g(·). How to define the neighborhood of an architecture is an open question. One possible method to obtain neighboring architectures is to perturb one or more operations in the architecture, where the degree of perturbation defines the scope of the neighborhood. This method can be applied to sampling-based search algorithms, e.g., random search and reinforcement learning.
However, it cannot be directly used to generate neighboring architectures for gradient-based search algorithms (a.k.a. differentiable NAS), where the neighboring architectures themselves must also be differentiable with respect to the architecture being learned. To address this issue, we propose a differentiable representation for the neighborhood of architectures, which makes the objective function differentiable and allows us to apply our formulation to gradient-based algorithms, e.g., DARTS (Liu et al., 2019). Properly choosing the aggregation function g(·) helps the search algorithm identify flat minima in the search space. Our choice of g(·) (e.g., the mean) is inspired by definitions of the flatness/sharpness of local minima in previous work (Chaudhari et al., 2017; Keskar et al., 2017; Dinh et al., 2017).

We summarize our contributions as follows:

1. We propose a neighborhood-aware NAS formulation based on the flat-minima assumption and demonstrate a principled way to apply it to existing search algorithms, including sampling-based and gradient-based algorithms. We empirically validate the assumption and show that flat-minima architectures generalize better than sharp ones.

2. We propose a neighborhood-aware random search (NA-RS) algorithm and demonstrate its superiority over standard random search. On NAS-Bench-201 (Dong & Yang, 2020), NA-RS outperforms standard random search by 1.48% on CIFAR-100 and 1.58% on ImageNet-16-120.

3. We propose a differentiable neighborhood representation so that we can apply our formulation to gradient-based NAS methods. We augment DARTS (Liu et al., 2019) with our formulation and name the resulting method NA-DARTS. NA-DARTS outperforms DARTS by 1.18% on CIFAR-100 and 1.2% on ImageNet, and performs better than or on par with state-of-the-art NAS methods.
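A minimal sketch of the neighborhood-aware search loop on a toy one-dimensional "search space" (the space, error table, and neighborhood below are illustrative assumptions, not the paper's benchmark setup); each sampled candidate is scored by the mean error over its neighborhood, which steers the search toward flat basins:

```python
import random
import statistics

def na_random_search(sample_arch, neighborhood, f, n_iters=200,
                     agg=statistics.mean, seed=0):
    """Neighborhood-aware random search: score each sampled architecture by
    the aggregated error over its neighborhood (including itself), i.e.,
    g(f(N(alpha) U {alpha})), so sharp minima with poor neighbors are penalized."""
    rng = random.Random(seed)
    best, best_score = None, float("inf")
    for _ in range(n_iters):
        alpha = sample_arch(rng)
        score = agg([f(a) for a in neighborhood(alpha) | {alpha}])
        if score < best_score:
            best, best_score = alpha, score
    return best

# Toy 1-D space: a sharp minimum at index 2 and a flat basin at indices 6-8.
errors = [5, 5, 1, 5, 5, 3, 2, 2, 2, 3]
sample = lambda rng: rng.randrange(len(errors))
neigh = lambda a: {max(a - 1, 0), min(a + 1, len(errors) - 1)}

print(na_random_search(sample, neigh, lambda a: errors[a]))  # an index in the flat basin
```

Scoring by f alone would select the sharp minimum (index 2, error 1); the aggregated objective instead selects a point inside the flat basin, whose mean neighborhood error (2.0 at index 7) beats the sharp minimum's (11/3 ≈ 3.67).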

2. RELATED WORK

Flat Minima. Hochreiter & Schmidhuber (1997) show that flat minima of the loss function of a neural network generalize better than sharp minima. Flat minima have been used to explain the poor generalization of large-batch methods (Keskar et al., 2017; Yao et al., 2018), which are shown to be more likely to converge to sharp minima. Chaudhari et al. (2017) propose an ob-

