METANORM: LEARNING TO NORMALIZE FEW-SHOT BATCHES ACROSS DOMAINS

Abstract

Batch normalization plays a crucial role when training deep neural networks. However, batch statistics become unstable with small batch sizes and are unreliable in the presence of distribution shifts. We propose MetaNorm, a simple yet effective meta-learning approach to normalization. It tackles the aforementioned issues in a unified way by leveraging the meta-learning setting and learns to infer adaptive statistics for batch normalization. MetaNorm is generic, flexible and model-agnostic, making it a simple plug-and-play module that can be seamlessly embedded into existing meta-learning approaches. It can be efficiently implemented by lightweight hypernetworks with low computational cost. We verify its effectiveness by extensive evaluation on representative tasks suffering from the small-batch and domain-shift problems: few-shot learning and domain generalization. We further introduce an even more challenging setting: few-shot domain generalization. Results demonstrate that MetaNorm consistently achieves better, or at least competitive, accuracy compared to existing batch normalization methods.

1. INTRODUCTION

Batch normalization (Ioffe & Szegedy, 2015) is crucial for training neural networks and, together with its variants, e.g., layer normalization (Ba et al., 2016), group normalization (Wu & He, 2018) and instance normalization (Ulyanov et al., 2016), has become an essential part of the deep learning toolkit (Bjorck et al., 2018; Luo et al., 2018a; Yang et al., 2019; Jia et al., 2019; Luo et al., 2018b; Summers & Dinneen, 2020). Batch normalization helps stabilize the distribution of internal activations while a model is being trained. Given a mini-batch $\mathcal{B}$, normalization is conducted along each individual feature channel for 2D convolutional neural networks. During training, the batch normalization moments are calculated as follows:

$$\mu_\mathcal{B} = \frac{1}{M}\sum_{i=1}^{M} a_i, \qquad \sigma^2_\mathcal{B} = \frac{1}{M}\sum_{i=1}^{M} \left(a_i - \mu_\mathcal{B}\right)^2,$$

where $a_i$ denotes the $i$-th of the $M$ activations in the batch, $M = |\mathcal{B}| \times H \times W$, with $H$ and $W$ the height and width of the feature map in each channel. We can then apply the normalization statistics to each activation:

$$a_i \leftarrow \mathrm{BN}(a_i) \equiv \gamma \hat{a}_i + \beta, \quad \text{where} \quad \hat{a}_i = \frac{a_i - \mu_\mathcal{B}}{\sqrt{\sigma^2_\mathcal{B} + \epsilon}},$$

where $\gamma$ and $\beta$ are parameters learned during training, $\epsilon$ is a small scalar that prevents division by zero, and operations between vectors are element-wise. At test time, the standard practice is to normalize activations using moving averages of the mini-batch means $\mu_\mathcal{B}$ and variances $\sigma^2_\mathcal{B}$.

Batch normalization rests on the implicit assumption that the samples in the dataset are independent and identically distributed. However, this assumption does not hold in challenging settings such as few-shot learning and domain generalization. In this paper, we strive for batch normalization that remains reliable when batches are small and suffer from distribution shifts between source and target domains. Batch normalization for few-shot learning and domain generalization has so far been considered separately, predominantly in a meta-learning setting. For few-shot meta-learning (Finn
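To make the standard formulation above concrete, the following is a minimal NumPy sketch of the train-time statistics and the test-time moving-average rule for activations flattened to shape (M, C), with M = |B| × H × W per channel. The function names, the `running` dictionary, and the `momentum` parameter are illustrative conventions, not from the paper, and this is vanilla batch normalization, not MetaNorm.

```python
import numpy as np

def batch_norm_train(a, gamma, beta, running, momentum=0.1, eps=1e-5):
    """One training-time batch-norm step.

    a: activations of shape (M, C), M = |B| * H * W flattened per channel.
    running: dict holding the moving averages used at test time.
    """
    mu = a.mean(axis=0)                      # per-channel batch mean  mu_B
    var = a.var(axis=0)                      # per-channel batch var   sigma^2_B
    a_hat = (a - mu) / np.sqrt(var + eps)    # normalize with batch statistics
    # Exponential moving averages of the batch moments for inference.
    running["mean"] = (1 - momentum) * running["mean"] + momentum * mu
    running["var"] = (1 - momentum) * running["var"] + momentum * var
    return gamma * a_hat + beta              # learned affine transform

def batch_norm_eval(a, gamma, beta, running, eps=1e-5):
    """Test-time batch norm: normalize with the moving averages instead
    of the current batch's statistics."""
    a_hat = (a - running["mean"]) / np.sqrt(running["var"] + eps)
    return gamma * a_hat + beta
```

The test-time path is exactly where small batches and domain shift hurt: the moving averages are estimated on source-domain training batches, so they can be both noisy (small M) and mismatched to the target distribution.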

