FEDFA: FEDERATED FEATURE AUGMENTATION

Abstract

Federated learning is a distributed paradigm that allows multiple parties to collaboratively train deep models without exchanging raw data. However, the data distribution among clients is naturally non-i.i.d., which leads to severe degradation of the learnt model. The primary goal of this paper is to develop a robust federated learning algorithm that addresses feature shift in clients' samples, which can be caused by various factors, e.g., acquisition differences in medical imaging. To reach this goal, we propose FEDFA, which tackles federated learning from the distinct perspective of federated feature augmentation. FEDFA is based on a major insight: each client's data distribution can be characterized by the statistics (i.e., mean and standard deviation) of its latent features, and these local statistics can be manipulated globally, i.e., based on information from the entire federation, to give clients a better sense of the underlying distribution and thereby alleviate local data bias. Based on this insight, we propose to augment each local feature statistic probabilistically via a normal distribution, whose mean is the original statistic and whose variance quantifies the augmentation scope. Key to our approach is the determination of a meaningful Gaussian variance, which is accomplished by taking into account not only the biased data of each individual client, but also the underlying feature statistics characterized by all participating clients. We offer both theoretical and empirical justifications to verify the effectiveness of FEDFA. Our code is available at https://github.com/tfzhou/FedFA.

1. INTRODUCTION

Federated learning (FL) (Konečnỳ et al., 2016) is an emerging collaborative training framework that enables training on decentralized data residing on devices like mobile phones. It comes with the promise of training centralized models from local data points such that the privacy of participating devices is preserved, and it has attracted significant attention in critical fields like healthcare and finance. Since data come from different users, it is inevitable that each user's data have a different underlying distribution, incurring large heterogeneity (non-iid-ness) among users' data. In this work, we focus on feature shift (Li et al., 2020b), which is common in many real-world settings, such as medical data acquired from different medical devices or natural images collected in diverse environments. While feature shift has been studied in classical centralized learning tasks like domain generalization, little is understood about how to tackle it in federated learning; (Li et al., 2020b; Reisizadeh et al., 2020; Jiang et al., 2022; Liu et al., 2020a) are rare exceptions. FEDROBUST (Reisizadeh et al., 2020) and FEDBN (Li et al., 2020b) address the problem through client-dependent learning, either fitting the shift with a client-specific affine distribution or learning unique BN parameters for each client. However, these algorithms may still suffer from significant local dataset bias. Other works (Qu et al., 2022; Jiang et al., 2022; Caldarola et al., 2022) learn robust models by adopting Sharpness Aware Minimization (SAM) (Foret et al., 2021) as the local optimizer, which, however, doubles the computational cost compared to SGD or Adam. Beyond model optimization, FEDHARMO (Jiang et al., 2022) has investigated specialized image normalization techniques to mitigate feature shift in medical domains.
Despite this progress, an alternative direction -- data augmentation -- remains largely unexplored in federated learning, even though it has been extensively studied in the centralized setting to impose regularization and improve generalizability (Zhou et al., 2021; Zhang et al., 2018). While seemingly straightforward, effective data augmentation is non-trivial in federated learning because users have no direct access to the external data of other users. Simply applying conventional augmentation techniques within each client is sub-optimal: without injecting global information, augmented samples will most likely still suffer from local dataset bias. To address this, FEDMIX (Yoon et al., 2021) generalizes MIXUP (Zhang et al., 2018) to federated learning by mixing averaged data across clients. The method performs augmentation at the input level, which is inherently limited in creating complicated and meaningful semantic transformations, e.g., make-bespectacled. Moreover, exchanging averaged data raises a certain level of privacy concern. In this work, we introduce a novel federation-aware augmentation technique, called FEDFA, into federated learning. FEDFA builds on the insight that the statistics of latent features capture essential domain-aware characteristics (Huang & Belongie, 2017; Zhou et al., 2021; Li et al., 2022a;b; 2021a), and can thus be treated as "features of a participating client". Accordingly, we argue that feature shift in FL -- whether the shift of each local data distribution from the underlying distribution, the distribution differences among clients, or even test-time distribution shift -- can be interpreted as a shift of feature statistics. This motivates us to directly address local feature statistic shift by incorporating universal statistics characterized by all participants in the federation.
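For concreteness, the input-level MIXUP operation that FEDMIX builds on interpolates two samples and their labels with a Beta-distributed coefficient. The sketch below is illustrative only (the function name and NumPy formulation are ours, not taken from either paper):

```python
import numpy as np

def mixup(x1, y1, x2, y2, alpha=0.2, rng=None):
    """Input-level MIXUP: convex combination of two samples and their
    (one-hot) labels, with mixing weight lam ~ Beta(alpha, alpha)."""
    rng = rng or np.random.default_rng()
    lam = rng.beta(alpha, alpha)
    x = lam * x1 + (1.0 - lam) * x2
    y = lam * y1 + (1.0 - lam) * y2
    return x, y
```

Because the combination happens in pixel space, such augmentation can only blend existing appearances; it cannot synthesize higher-level semantic variations, which is the limitation noted above.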
FEDFA instantiates this idea by augmenting the feature statistics of each sample online during local model training, so as to make the model robust to certain changes in the "features of a participating client". Concretely, we model the augmentation procedure probabilistically via a multivariate Gaussian distribution: the Gaussian mean is fixed to the original statistic, and the variance reflects the potential local distribution shift. In this manner, novel statistics can be effortlessly synthesized by drawing samples from the Gaussian distribution. For effective augmentation, we determine a reasonable variance based not only on the variances of feature statistics within each client, but also on universal variances characterized by all participating clients. The augmentation in FEDFA allows each local model to be trained over samples drawn from more diverse feature distributions, alleviating local distribution shift and encouraging client-invariant representation learning, which eventually contributes to a better global model. FEDFA is conceptually simple but surprisingly effective. It is non-parametric, requires negligible additional computation and communication costs, and can be seamlessly incorporated into arbitrary CNN architectures. We provide both theoretical and empirical insights. Theoretically, we show that FEDFA implicitly introduces regularization into local model learning by regularizing the gradients of latent representations, weighted by variances of feature statistics estimated from the entire federation. Empirically, we demonstrate that FEDFA (1) works favorably with extremely small local datasets; (2) shows remarkable generalization performance on unseen test clients outside of the federation; (3) outperforms traditional data augmentation techniques by solid margins, and complements them well in the federated learning setup.
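The statistic-level augmentation described above can be sketched as follows: normalize a feature map by its original channel-wise statistics, then re-style it with statistics sampled from Gaussians centered at the originals. This is a simplified illustration, not the paper's implementation: here the perturbation scales `sigma_mu` and `sigma_std` are supplied by the caller, whereas FEDFA estimates them from client-level and federation-level statistic variances; the function name and signature are our assumptions.

```python
import numpy as np

def augment_feature_stats(feat, sigma_mu, sigma_std, rng=None):
    """Perturb the channel-wise mean/std of one feature map.

    feat:      (C, H, W) latent feature map of a single sample.
    sigma_mu:  (C,) std. dev. of the Gaussian perturbing the channel mean.
    sigma_std: (C,) std. dev. of the Gaussian perturbing the channel std.
    """
    rng = rng or np.random.default_rng()
    mu = feat.mean(axis=(1, 2), keepdims=True)          # original mean, (C,1,1)
    std = feat.std(axis=(1, 2), keepdims=True) + 1e-6   # original std,  (C,1,1)
    # Sample novel statistics from Gaussians centered at the originals.
    mu_new = mu + sigma_mu.reshape(-1, 1, 1) * rng.standard_normal(mu.shape)
    std_new = std + sigma_std.reshape(-1, 1, 1) * rng.standard_normal(std.shape)
    # Normalize with the original statistics, re-style with the sampled ones.
    return std_new * (feat - mu) / std + mu_new
```

With both scales set to zero the operation reduces to the identity, so the augmentation strength is controlled entirely by how the Gaussian variances are chosen.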

2.1. PRELIMINARY: FEDERATED LEARNING

We assume a standard federated learning setup with a server that can transmit and receive messages from M client devices. Each client m ∈ [M] has access to N_m training instances {(x_i, y_i)}_{i=1}^{N_m}, with images x_i ∈ X and corresponding labels y_i ∈ Y drawn i.i.d. from a device-indexed joint distribution, i.e., (x_i, y_i) ∼ P_m(x, y). The goal of standard federated learning is to learn a deep neural network f(w_g, w_h) = g(w_g) ∘ h(w_h), where h : X → Z is a feature extractor with K convolutional stages, h = h_K ∘ h_{K-1} ∘ ⋯ ∘ h_1, and g : Z → Y is a classifier. To learn the network parameters w = {w_g, w_h}, empirical risk minimization (ERM) is widely used:

L^ERM(w) = (1/M) Σ_{m∈[M]} L^ERM_m(w),   where   L^ERM_m(w) = E_{(x_i, y_i)∼P_m}[ℓ_i(g ∘ h(x_i), y_i; w)].

Here the global objective L^ERM decomposes as a sum of device-level empirical objectives (i.e., {L^ERM_m}_m), each computed with a per-sample loss function ℓ_i. Because clients' data are kept separate, L^ERM(w) cannot be minimized directly. FEDAVG (McMahan et al., 2017) is a leading algorithm to address this. It starts with client training on all clients in parallel, each client optimizing L^ERM_m independently. After local client training, FEDAVG performs model aggregation, averaging all client models into an updated global model, which is then distributed back to the clients for the next round of client training. The client training objective in FEDAVG is equivalent to empirically approximating the local distribution P_m by its N_m examples, i.e., P̂_m(x, y) = (1/N_m) Σ_{i=1}^{N_m} δ(x = x_i, y = y_i), where δ(x = x_i, y = y_i) is a Dirac mass centered at (x_i, y_i).
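FEDAVG's aggregation step described above can be sketched as a dataset-size-weighted average of client parameters. The sketch assumes a dict-of-arrays model representation; the helper name and this representation are our choices for illustration:

```python
import numpy as np

def fedavg_aggregate(client_weights, client_sizes):
    """FEDAVG aggregation: average client models, weighting each client m
    by its local dataset size N_m.

    client_weights: list of dicts mapping parameter name -> np.ndarray.
    client_sizes:   list of N_m values, one per client.
    """
    total = float(sum(client_sizes))
    aggregated = {}
    for key in client_weights[0]:
        aggregated[key] = sum(
            (n / total) * w[key] for w, n in zip(client_weights, client_sizes)
        )
    return aggregated
```

The server would then broadcast the aggregated parameters back to all clients for the next round of local training.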

