CLIENT-AGNOSTIC LEARNING AND ZERO-SHOT ADAPTATION FOR FEDERATED DOMAIN GENERALIZATION

Abstract

Federated domain generalization (federated DG) aims to learn a client-agnostic global model from various distributed source domains and to generalize the model to new clients in completely unseen domains. Federated DG poses two main challenges: (1) building a global model from local client models trained on different domains while keeping data private, and (2) low generalizability to test clients whose data distributions deviate from those of the training clients. To address these challenges, we present two strategies: (1) client-agnostic learning with mixed instance-global statistics and (2) zero-shot adaptation with estimated statistics. In client-agnostic learning, we first augment local features with the data distributions of other clients via the global statistics stored in the global model's batch normalization layers. This approach generates diverse domains by mixing local and global feature statistics while keeping data private. Local models then learn client-invariant representations by applying our client-agnostic objectives to the augmented data. Next, we propose a zero-shot adapter that helps the learned global model directly bridge the large domain gap between seen and unseen clients. At inference time, the adapter mixes the instance statistics of a test input with the global statistics, which are vulnerable to distribution shift. With the aid of the adapter, the global model further improves generalizability by reflecting the test distribution. We comprehensively evaluate our methods on several federated DG benchmarks.
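The core operation described above, normalizing a feature map with statistics interpolated between an instance's own statistics and the global batch-normalization statistics, can be sketched as follows. This is a minimal illustrative NumPy sketch, not the paper's actual implementation; the function name `mix_statistics` and the mixing coefficient `lam` are hypothetical.

```python
import numpy as np

def mix_statistics(x, global_mean, global_var, lam=0.5, eps=1e-5):
    """Normalize one instance's feature map with mixed statistics.

    x:           (C, H, W) feature map of a single instance.
    global_mean,
    global_var:  (C,) running statistics from a BN layer of the global model.
    lam:         mixing ratio; lam=1 uses pure instance statistics,
                 lam=0 uses pure global statistics (hypothetical parameter).
    """
    # Per-channel statistics of this instance alone.
    inst_mean = x.mean(axis=(1, 2))
    inst_var = x.var(axis=(1, 2))
    # Interpolate between instance and global statistics.
    mean = lam * inst_mean + (1 - lam) * global_mean
    var = lam * inst_var + (1 - lam) * global_var
    # Normalize with the mixed statistics.
    return (x - mean[:, None, None]) / np.sqrt(var[:, None, None] + eps)
```

Varying `lam` yields features normalized under interpolations of the local and global distributions, which is how such mixing can simulate diverse domains without exchanging raw data.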

1. INTRODUCTION

A huge amount of data is collected every second from a wide range of IoT devices, and these data have been utilized to build robust deep learning models. Federated learning (FL) has emerged as a promising paradigm for training a model without directly accessing the distributed data, thereby reducing privacy leakage. Pioneering studies such as FedAvg (McMahan et al., 2017) and FedProx (Li et al., 2020) train each local model on its own data while keeping the data private, and transmit model parameters to the server to obtain a generalized global model. The parameters from local clients are aggregated at the server, and the server parameters are broadcast back to the clients. This process is performed iteratively until the global model converges to a stationary point, and user privacy is preserved by sharing aggregated parameters rather than the data itself. In real-world scenarios, local data are collected from various domains across clients, reflecting different sensor characteristics and surrounding environments. For example, in autonomous driving tasks, each vehicle captures street views and infrastructure differently from others due to variance in camera sensors, region, and other factors. These local data deviate from one another in feature-space distribution, inducing non-iid data across clients, a problem known as domain shift (Li et al., 2021b; Jiang et al., 2021). Most studies to date have tried to solve the issues of FL on non-iid data, especially heterogeneous label distributions (Li et al., 2020; Karimireddy et al., 2020; Wang et al., 2020), but domain shift has not yet been fully explored in the literature. Domain shift also exists between training and test clients: after federated learning, the learned FL model is deployed to new customers outside the federation, e.g., new vehicles or medical centers, whose data distributions are shifted from those of clients inside the federation.
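The aggregation step of the train/aggregate/broadcast loop described above can be sketched as a sample-size-weighted parameter average, as in FedAvg. This is a simplified illustration assuming parameters are flat lists of floats; the function name `fedavg_aggregate` is hypothetical, not from any library.

```python
def fedavg_aggregate(client_params, client_sizes):
    """Server-side FedAvg aggregation sketch.

    client_params: list of per-client parameter dicts, {name: list of floats}.
    client_sizes:  number of local training samples per client, used as
                   aggregation weights.
    Returns the weighted average of each parameter, which the server would
    then broadcast back to all clients for the next round.
    """
    total = sum(client_sizes)
    aggregated = {}
    for name in client_params[0]:
        aggregated[name] = [
            sum(n * params[name][i] for n, params in zip(client_sizes, client_params)) / total
            for i in range(len(client_params[0][name]))
        ]
    return aggregated
```

Because only these averaged parameters (never raw samples) leave each client, the loop keeps local data private while still producing a shared global model.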
However, most works focus only on improving model performance for the clients that participated in FL, while neglecting generalization to unseen clients. In this paper, we address federated domain generalization (federated DG), which aims to collaboratively learn a client-agnostic federated model from various distributed source domains and to generalize the learned model to new clients in unseen domains, as illustrated in Fig. 1.

