FEDERATED TRAINING OF DUAL ENCODING MODELS ON SMALL NON-IID CLIENT DATASETS

Abstract

Dual encoding models that encode a pair of inputs are widely used for representation learning. Many approaches train dual encoding models by maximizing agreement between pairs of encodings on centralized training data. However, in many scenarios, datasets are inherently decentralized across many clients (user devices or organizations) due to privacy concerns, motivating federated learning. In this work, we focus on federated training of dual encoding models on decentralized data composed of many small, non-IID (not independent and identically distributed) client datasets. We show that existing approaches that work well in centralized settings perform poorly when naively adapted to this setting using federated averaging. We observe that large-batch loss computation can be simulated on individual clients for loss functions that are based on encoding statistics. Based on this insight, we propose a novel federated training approach, Distributed Cross Correlation Optimization (DCCO), which trains dual encoding models using encoding statistics aggregated across clients, without sharing individual data samples. Our experimental results on two datasets demonstrate that the proposed DCCO approach outperforms federated variants of existing approaches by a large margin.

1. INTRODUCTION

Dual encoding models (see Fig. 1) are a class of models that generate a pair of encodings for a pair of inputs, either by processing both inputs using the same network or by using two different networks. These models have been widely successful in self-supervised representation learning from unlabeled unimodal data (Chen et al., 2020a; Chen & He, 2021; Zbontar et al., 2021; He et al., 2020; Grill et al., 2020), and are also a natural choice for representation learning from paired multi-modal data (Jia et al., 2021; Radford et al., 2021; Bardes et al., 2022). While several approaches exist for training dual encoding models in centralized settings (where the entire training data is present on a central server), training these models on decentralized datasets is less explored. Due to data privacy concerns, learning from decentralized datasets is becoming increasingly important. Federated learning (McMahan et al., 2017) is a widely-used approach for learning from decentralized datasets without transferring raw data to a central server. In each round of federated training, each participating client computes a model update using its local data, and then a central server aggregates the client model updates and performs a global model update. In many real-world scenarios, individual client datasets are small and non-IID (not independent and identically distributed), e.g., in cross-device federated settings (Kairouz et al., 2021; Wang et al., 2021). For example, in the context of mobile medical apps such as Aysa (AskAysa, 2022) and DermAssist (DermAssist, 2022), each user contributes only a few (1-3) images. Motivated by this, we focus on federated training of dual encoding models on decentralized data composed of a large number of small, non-IID client datasets.
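A round of FedAvg as described above can be sketched in a few lines. The snippet below is an illustrative simulation, not the paper's training loop; `local_step` is a hypothetical placeholder for any function that returns a gradient of the local loss, and the server averages client updates weighted by dataset size, as in McMahan et al. (2017).

```python
import numpy as np

def fedavg_round(global_w, client_datasets, local_step, num_local_steps=1, lr=0.1):
    """One FedAvg round: each client trains locally starting from the
    global weights; the server averages the resulting updates,
    weighted by client dataset size."""
    updates, sizes = [], []
    for data in client_datasets:
        w = global_w.copy()
        for _ in range(num_local_steps):
            w = w - lr * local_step(w, data)  # local_step returns a gradient
        updates.append(w - global_w)
        sizes.append(len(data))
    total = sum(sizes)
    avg_update = sum((n / total) * u for n, u in zip(sizes, updates))
    return global_w + avg_update
```

With one local step per round, the weighted average of client gradients equals the gradient over the pooled data, so repeated rounds behave like centralized gradient descent for simple objectives.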
Recently, several highly successful approaches have been proposed for training dual encoding models in centralized settings based on contrastive losses (He et al., 2020; Chen et al., 2020a;b), statistics-based losses (Zbontar et al., 2021; Bardes et al., 2022), and predictive losses combined with batch normalization (Ioffe & Szegedy, 2015) and a stop-gradient operation (Grill et al., 2020; Chen & He, 2021). One way to enable federated training of dual encoding models is to adapt these existing approaches using the Federated Averaging (FedAvg) strategy of McMahan et al. (2017).

In this work, we observe that, in the case of statistics-based loss functions, we can simulate large-batch loss computation on each individual (small) client by first aggregating encoding statistics from many clients and then sharing the aggregated large-batch statistics with all the clients that contributed to them. Based on this observation, we propose a novel approach, Distributed Cross Correlation Optimization (DCCO), for federated training of dual encoding models on small, non-IID client datasets. The proposed approach simulates large-batch training with the loss function of Zbontar et al. (2021), which we refer to as the Cross Correlation Optimization (CCO) loss, without sharing individual data samples or their encodings between clients.
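The statistic-sharing step can be sketched concretely. In the numpy sketch below (an illustration of the idea, not the authors' exact protocol; all function names are ours), each client sends only per-feature sums, sums of squares, and the sum of outer products of its two encoding views; the server combines these into the cross-correlation matrix of the standardized encodings over the full "virtual" batch, on which a Barlow Twins-style CCO loss can be evaluated.

```python
import numpy as np

def client_stats(za, zb):
    """Per-client raw sums over local encodings; only these
    statistics (never the samples or encodings) leave the client."""
    return {
        "n": za.shape[0],
        "sum_a": za.sum(axis=0), "sum_b": zb.sum(axis=0),
        "sumsq_a": (za ** 2).sum(axis=0), "sumsq_b": (zb ** 2).sum(axis=0),
        "cross": za.T @ zb,  # d x d sum of outer products
    }

def aggregate_cross_correlation(stats_list, eps=1e-9):
    """Server side: combine client sums into the cross-correlation
    matrix C of the standardized encodings over the aggregate batch."""
    n = sum(s["n"] for s in stats_list)
    sa = sum(s["sum_a"] for s in stats_list)
    sb = sum(s["sum_b"] for s in stats_list)
    qa = sum(s["sumsq_a"] for s in stats_list)
    qb = sum(s["sumsq_b"] for s in stats_list)
    cr = sum(s["cross"] for s in stats_list)
    mu_a, mu_b = sa / n, sb / n
    std_a = np.sqrt(qa / n - mu_a ** 2 + eps)
    std_b = np.sqrt(qb / n - mu_b ** 2 + eps)
    # cross-covariance from raw sums, then normalize per dimension
    return (cr / n - np.outer(mu_a, mu_b)) / np.outer(std_a, std_b)

def cco_loss(c, lam=5e-3):
    """Cross-correlation objective: drive the diagonal of C to 1
    (invariance) and the off-diagonal to 0 (redundancy reduction)."""
    off = c - np.diag(np.diag(c))
    return ((1.0 - np.diag(c)) ** 2).sum() + lam * (off ** 2).sum()
```

Because the raw sums are additive across clients, the aggregated matrix is identical (up to the small `eps`) to the one a centralized trainer would compute on the concatenation of all client batches.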

MAJOR CONTRIBUTIONS

• We observe that large-batch training of dual encoding models can be simulated on decentralized datasets by using loss functions based on encoding statistics aggregated across clients, without sharing individual samples.
• Building on this insight, we present Distributed Cross Correlation Optimization (DCCO), a novel approach for training dual encoding models on decentralized data composed of a large number of small, non-IID client datasets.
• We prove that, when each client performs one step of local training per federated round, one round of DCCO training is equivalent to one step of centralized training on a large batch composed of all samples across all clients participating in that round.
• We evaluate the proposed DCCO approach on CIFAR-100 and dermatology datasets, and show that it outperforms FedAvg variants of contrastive and CCO training by a significant margin. It also significantly outperforms supervised training from scratch, demonstrating its effectiveness for decentralized self-supervised learning.

2. PROBLEMS WITH EXISTING APPROACHES

Contrastive loss functions Contrastive losses explicitly maximize similarity between the two encodings of a data sample while pushing encodings of different samples apart. This is highly effective when each sample is contrasted against a large set of diverse samples (Chen et al., 2020a;b; He et al., 2020; Radford et al., 2021; Jia et al., 2021). Contrastive learning approaches can be extended to federated settings by combining within-client contrastive training with the FedAvg strategy of McMahan et al. (2017). However, this reduces performance, because each sample is then contrasted only against the small set of within-client samples, which may also be relatively similar to one another.
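The dependence on batch composition is visible directly in the loss. Below is a minimal numpy sketch of an NT-Xent-style contrastive loss in the spirit of Chen et al. (2020a) (our simplified illustration, not their exact implementation): every encoding other than a sample's own positive serves as a negative, so a client holding only 1-3 samples supplies at most a handful of negatives, and non-IID clients make those few negatives likely to be similar to the anchor.

```python
import numpy as np

def nt_xent(za, zb, tau=0.5):
    """NT-Xent (normalized temperature-scaled cross-entropy) loss over a
    batch of N paired encodings. Each of the 2N encodings treats its
    other view as the positive and the remaining 2N-2 as negatives."""
    n = za.shape[0]
    z = np.concatenate([za, zb], axis=0)
    z = z / np.linalg.norm(z, axis=1, keepdims=True)  # cosine similarity
    sim = z @ z.T / tau
    np.fill_diagonal(sim, -np.inf)  # exclude self-similarity
    # row i's positive: its other view (i+n or i-n)
    pos = np.concatenate([np.arange(n, 2 * n), np.arange(n)])
    logprob = sim - np.log(np.exp(sim).sum(axis=1, keepdims=True))
    return -logprob[np.arange(2 * n), pos].mean()
```

With a client batch of N = 2, each anchor sees only two negatives; the softmax denominator is nearly trivial to "win", which weakens the learning signal relative to large centralized batches.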



Figure 1: Dual encoding models. Two inputs generated using random data augmentations (T ∼ 𝒯) are encoded by (a) the same encoder network or (b) two different networks. In (c), aligned inputs from two different modalities are encoded by two different modality-specific networks.

