FEDERATED TRAINING OF DUAL ENCODING MODELS ON SMALL NON-IID CLIENT DATASETS

Abstract

Dual encoding models that encode a pair of inputs are widely used for representation learning. Many approaches train dual encoding models by maximizing agreement between pairs of encodings on centralized training data. However, in many scenarios, datasets are inherently decentralized across many clients (user devices or organizations) due to privacy concerns, motivating federated learning. In this work, we focus on federated training of dual encoding models on decentralized data composed of many small, non-IID (not independent and identically distributed) client datasets. We show that existing approaches that work well in centralized settings perform poorly when naively adapted to this setting using federated averaging. We observe that, for loss functions based on encoding statistics, large-batch loss computation can be simulated on individual clients. Based on this insight, we propose a novel federated training approach, Distributed Cross Correlation Optimization (DCCO), which trains dual encoding models using encoding statistics aggregated across clients, without sharing individual data samples. Our experimental results on two datasets demonstrate that the proposed DCCO approach outperforms federated variants of existing approaches by a large margin.

1. INTRODUCTION

Dual encoding models (see Fig. 1) are a class of models that generate a pair of encodings for a pair of inputs, either by processing both inputs with the same network or with two different networks. These models have been widely successful in self-supervised representation learning on unlabeled unimodal data (Chen et al., 2020a; Chen & He, 2021; Zbontar et al., 2021; He et al., 2020; Grill et al., 2020), and are also a natural choice for representation learning on paired multi-modal data (Jia et al., 2021; Radford et al., 2021; Bardes et al., 2022). While several approaches exist for training dual encoding models in centralized settings (where the entire training dataset resides on a central server), training these models on decentralized datasets is less explored.

Learning from decentralized datasets is becoming increasingly important due to data privacy concerns. Federated learning (McMahan et al., 2017) is a widely-used approach for learning from decentralized datasets without transferring raw data to a central server. In each round of federated training, each participating client computes a model update using its local data, and a central server then aggregates the client model updates and performs a global model update. In many real-world scenarios, individual client datasets are small and non-IID (not independent and identically distributed), e.g., in cross-device federated settings (Kairouz et al., 2021; Wang et al., 2021). For example, in the context of mobile medical apps such as Aysa (AskAysa, 2022) and DermAssist (DermAssist, 2022), each user contributes only a few (1-3) images. Motivated by this, we focus on federated training of dual encoding models on decentralized data composed of a large number of small, non-IID client datasets.

Recently, several highly successful approaches have been proposed for training dual encoding models in centralized settings based on contrastive losses (He et al., 2020; Chen et al., 2020a;b), statistics-based losses (Zbontar et al., 2021; Bardes et al., 2022), and predictive losses combined with batch normalization (Ioffe & Szegedy, 2015) and a stop-gradient operation (Grill et al., 2020; Chen & He, 2021). One way to enable federated training of dual encoding models is to adapt these existing approaches using the Federated Averaging (FedAvg) strategy of McMahan et al. (2017). As
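
To make the training setup concrete, below is a minimal sketch of a dual encoding model trained with a statistics-based objective, written in PyTorch with a shared encoder and a Barlow Twins-style cross-correlation loss (Zbontar et al., 2021). The architecture, hyperparameters, and names here are illustrative assumptions, not the paper's implementation.

    import torch
    import torch.nn as nn

    class DualEncoder(nn.Module):
        """Encodes a pair of inputs with a single shared network; a
        two-tower variant would instead hold two separate encoders."""
        def __init__(self, dim_in: int, dim_out: int):
            super().__init__()
            self.encoder = nn.Sequential(
                nn.Linear(dim_in, 256), nn.ReLU(), nn.Linear(256, dim_out)
            )

        def forward(self, x_a: torch.Tensor, x_b: torch.Tensor):
            return self.encoder(x_a), self.encoder(x_b)

    def cross_correlation_loss(z_a: torch.Tensor, z_b: torch.Tensor,
                               lam: float = 5e-3) -> torch.Tensor:
        """Barlow Twins-style objective: push the cross-correlation matrix
        of the two batch-normalized encodings toward the identity."""
        n, _ = z_a.shape
        z_a = (z_a - z_a.mean(0)) / (z_a.std(0) + 1e-6)  # per-dim normalization
        z_b = (z_b - z_b.mean(0)) / (z_b.std(0) + 1e-6)
        c = (z_a.T @ z_b) / n                            # d x d batch statistics
        on_diag = (torch.diagonal(c) - 1.0).pow(2).sum()
        off_diag = (c - torch.diag(torch.diagonal(c))).pow(2).sum()
        return on_diag + lam * off_diag

For a batch of paired views x_a and x_b, the loss is computed as cross_correlation_loss(*model(x_a, x_b)). Note that this loss depends on the batch only through the d x d cross-correlation statistics, which is the property that allows such statistics to be aggregated across clients.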
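
Similarly, a minimal sketch of one Federated Averaging round is given below, assuming full-model exchange and a hypothetical client-side local_update routine; uniform weighting is used here for brevity, whereas FedAvg weights clients by their dataset sizes.

    import copy

    def fedavg_round(global_model, clients, local_update, weights=None):
        """One FedAvg communication round: each participating client trains
        a copy of the global model on its local data, and the server
        averages the resulting parameters (assumes all parameters and
        buffers in the state dict are floating point)."""
        states = []
        for client_data in clients:
            local_model = copy.deepcopy(global_model)
            local_update(local_model, client_data)  # client-side SGD steps
            states.append(local_model.state_dict())

        if weights is None:  # uniform averaging for brevity
            weights = [1.0 / len(states)] * len(states)

        avg_state = {
            key: sum(w * s[key] for w, s in zip(weights, states))
            for key in states[0]
        }
        global_model.load_state_dict(avg_state)
        return global_model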

