METANORM: LEARNING TO NORMALIZE FEW-SHOT BATCHES ACROSS DOMAINS

Abstract

Batch normalization plays a crucial role when training deep neural networks. However, batch statistics become unstable with small batch sizes and are unreliable in the presence of distribution shifts. We propose MetaNorm, a simple yet effective meta-learning normalization approach. It tackles the aforementioned issues in a unified way by leveraging the meta-learning setting and learning to infer adaptive statistics for batch normalization. MetaNorm is generic, flexible and model-agnostic, making it a simple plug-and-play module that can be seamlessly embedded into existing meta-learning approaches. It can be efficiently implemented by lightweight hypernetworks with low computational cost. We verify its effectiveness by extensive evaluation on representative tasks suffering from the small-batch and domain-shift problems: few-shot learning and domain generalization. We further introduce an even more challenging setting: few-shot domain generalization. Results demonstrate that MetaNorm consistently achieves better, or at least competitive, accuracy compared to existing batch normalization methods.

1. INTRODUCTION

Batch normalization (Ioffe & Szegedy, 2015) is crucial for training neural networks and, together with its variants, e.g., layer normalization (Ba et al., 2016), group normalization (Wu & He, 2018) and instance normalization (Ulyanov et al., 2016), has become an essential part of the deep learning toolkit (Bjorck et al., 2018; Luo et al., 2018a; Yang et al., 2019; Jia et al., 2019; Luo et al., 2018b; Summers & Dinneen, 2020). Batch normalization helps stabilize the distribution of internal activations while a model is being trained. Given a mini-batch B, the normalization is conducted along each individual feature channel for 2D convolutional neural networks. During training, the batch normalization moments are calculated as follows:

$\mu_B = \frac{1}{M}\sum_{i=1}^{M} a_i, \qquad \sigma_B^2 = \frac{1}{M}\sum_{i=1}^{M} (a_i - \mu_B)^2,$ (1)

where $a_i$ denotes the $i$-th of the $M$ activations in the batch, with $M = |B| \times H \times W$, in which $H$ and $W$ are the height and width of the feature map in each channel. We can now apply the normalization statistics to each activation:

$a_i \leftarrow \mathrm{BN}(a_i) \equiv \gamma \hat{a}_i + \beta, \quad \text{where} \quad \hat{a}_i = \frac{a_i - \mu_B}{\sqrt{\sigma_B^2 + \epsilon}},$

where $\gamma$ and $\beta$ are parameters learned during training, $\epsilon$ is a small scalar to prevent division by zero, and operations between vectors are element-wise. At test time, the standard practice is to normalize activations using moving averages of the mini-batch means $\mu_B$ and variances $\sigma_B^2$.

Batch normalization rests on the implicit assumption that the samples in the dataset are independent and identically distributed. However, this assumption does not hold in challenging settings like few-shot learning and domain generalization. In this paper, we strive for batch normalization that remains reliable when batches are small and suffer from distribution shifts between source and target domains. Batch normalization for few-shot learning and domain generalization has so far been considered separately for the two problems, predominantly in a meta-learning setting.
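The per-channel moment computation and normalization above can be sketched as follows; this is a minimal NumPy illustration of standard training-time batch normalization (the function name and shapes are our own choices, not from the paper):

```python
import numpy as np

def batch_norm_2d(x, gamma, beta, eps=1e-5):
    """Training-time batch normalization for 2D feature maps.

    x: activations of shape (B, C, H, W); each channel's M = B*H*W values
    are normalized with that channel's batch moments mu_B and sigma^2_B.
    gamma, beta: learned per-channel scale and shift, shape (C,).
    """
    mu = x.mean(axis=(0, 2, 3), keepdims=True)    # mu_B, one per channel
    var = x.var(axis=(0, 2, 3), keepdims=True)    # sigma^2_B, one per channel
    x_hat = (x - mu) / np.sqrt(var + eps)         # normalize activations
    # Scale and shift with the learned parameters, element-wise per channel.
    return gamma[None, :, None, None] * x_hat + beta[None, :, None, None]

# Example: a batch of 4 feature maps with 3 channels of size 8x8.
rng = np.random.default_rng(0)
x = rng.normal(size=(4, 3, 8, 8))
y = batch_norm_2d(x, gamma=np.ones(3), beta=np.zeros(3))
```

With $\gamma = 1$ and $\beta = 0$, each channel of the output has approximately zero mean and unit variance over the batch and spatial dimensions.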
For few-shot meta-learning (Finn et al., 2017; Gordon et al., 2019), most existing methods rely critically on transductive batch normalization, except those based on prototypes (Snell et al., 2017; Allen et al., 2019; Zhen et al., 2020a). However, the nature of transductive learning restricts its application due to the requirement to sample from the test set. To address this issue, Bronskill et al. (2020) propose TaskNorm, which leverages additional statistics from layer and instance normalization. As a non-transductive normalization approach, it achieves impressive performance and outperforms conventional batch normalization (Ioffe & Szegedy, 2015). However, it does not always outperform transductive batch normalization. Meanwhile, domain generalization (Muandet et al., 2013; Balaji et al., 2018; Li et al., 2017a; b) suffers from distribution shifts from training to test, which makes it problematic to directly apply statistics calculated from a seen domain to test data from unseen domains (Wang et al., 2019; Seo et al., 2019). Recent works deal with this problem by learning a domain-specific normalization (Chang et al., 2019; Seo et al., 2019) or a transferable normalization in place of existing normalization techniques (Wang et al., 2019).

We address the batch normalization challenges for few-shot classification and domain generalization in a unified way by learning a new batch normalization under the meta-learning setting. We propose MetaNorm, a simple but effective meta-learning normalization. We leverage the meta-learning setting and learn to infer normalization statistics from data, instead of applying direct calculations or blending various normalization statistics. MetaNorm is a general batch normalization approach, which is model-agnostic and serves as a plug-and-play module that can be seamlessly embedded into existing meta-learning approaches.
We demonstrate its effectiveness for few-shot classification and domain generalization: for each few-shot task, it learns task-specific statistics from the limited data samples in the support set, and it can also learn to generate domain-specific statistics from seen source domains for unseen target domains. We verify the effectiveness of MetaNorm by extensive evaluation on few-shot classification and domain generalization tasks. For few-shot classification, we experiment with representative gradient-, metric- and model-based meta-learning approaches on fourteen benchmark datasets. For domain generalization, we evaluate the model on three widely-used benchmarks for cross-domain visual object classification. Last but not least, we introduce the challenging new task of few-shot domain generalization, which combines the challenges of both few-shot learning and domain generalization. The experimental results demonstrate the benefit of MetaNorm compared to existing batch normalization methods.

2. RELATED WORK

Transductive Batch Normalization For conventional batch normalization under supervised settings, i.i.d. assumptions about the data distribution imply that estimating moments from the training set will provide appropriate normalization statistics for test data. However, in the meta-learning scenario data points are only assumed to be i.i.d. within a specific task. It is therefore critical how the moments are selected when batch normalization is applied to support and query set data points during meta-training and meta-testing. Hence, in the recent meta-learning literature the running moments are no longer used for normalization at meta-test time, but are instead replaced with support/query set statistics. These statistics are used for normalization at both meta-train and meta-test time. This approach is referred to as transductive batch normalization (TBN) (Bronskill et al., 2020). Competitive meta-learning methods (e.g., Gordon et al., 2019; Finn et al., 2017; Zhen et al., 2020b) rely on TBN to achieve state-of-the-art performance. However, there are two critical problems with TBN. First, TBN is sensitive to the distribution over the query set used during meta-training, and as such is less generally applicable than non-transductive learning. Second, compared to non-transductive batch normalization, TBN uses extra information from multiple test samples at prediction time, which could be problematic as we are not guaranteed to have a set of test samples available in practical applications. In contrast, MetaNorm is a non-transductive normalization: it generates statistics from the support set only, without relying on query samples, making it more practical.

Meta Batch Normalization To address the problems of transductive batch normalization and improve conventional batch normalization, meta-batch normalization (MetaBN) was introduced (Triantafillou et al., 2020; Bronskill et al., 2020).
In MetaBN, the support set alone is used to compute the normalization statistics for both the support and query sets at both meta-training and meta-test time. MetaBN is non-transductive since the normalization of a test input does not depend on other test inputs in the query set. However, Bronskill et al. (2020) observe that MetaBN performs less well for small support sets, as these lead to high variance in the moment estimates, which is similar to the
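The difference between TBN and MetaBN comes down to which set supplies the moments. A minimal sketch, with our own hypothetical helper names and shapes (not from either paper):

```python
import numpy as np

def moments(x):
    # Per-channel mean and variance over batch and spatial dimensions.
    return (x.mean(axis=(0, 2, 3), keepdims=True),
            x.var(axis=(0, 2, 3), keepdims=True))

def normalize(x, mu, var, eps=1e-5):
    return (x - mu) / np.sqrt(var + eps)

rng = np.random.default_rng(0)
support = rng.normal(size=(5, 3, 8, 8))    # small support set of a task
query = rng.normal(size=(15, 3, 8, 8))     # query set of the same task

# TBN (transductive): the query set is normalized with its own moments,
# so each query prediction depends on the other query examples.
q_tbn = normalize(query, *moments(query))

# MetaBN (non-transductive): support moments normalize both sets, so a
# query example's normalization is independent of the rest of the query set.
mu_s, var_s = moments(support)
s_metabn = normalize(support, mu_s, var_s)
q_metabn = normalize(query, mu_s, var_s)
```

The sketch also makes the variance issue visible: with only five support examples, `mu_s` and `var_s` are noisy estimates, and that noise propagates to every normalized query activation.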

