FEDERATED LEARNING FROM SMALL DATASETS

Abstract

Federated learning allows multiple parties to collaboratively train a joint model without having to share any local data. It enables applications of machine learning in settings where data is inherently distributed and undisclosable, such as in the medical domain. Joint training is usually achieved by aggregating local models. When local datasets are small, locally trained models can vary greatly from a globally good model. Bad local models can arbitrarily deteriorate the aggregate model quality, causing federated learning to fail in these settings. We propose a novel approach that avoids this problem by interleaving model aggregation and permutation steps. During a permutation step we redistribute local models across clients through the server, while preserving data privacy, to allow each local model to train on a daisy chain of local datasets. This enables successful training in data-sparse domains. Combined with model aggregation, this approach enables effective learning even if the local datasets are extremely small, while retaining the privacy benefits of federated learning.

1. INTRODUCTION

How can we learn high-quality models when data is inherently distributed across sites and cannot be shared or pooled? In federated learning, the solution is to iteratively train models locally at each site and share these models with the server to be aggregated into a global model. As only models are shared, data usually remains undisclosed. This process, however, requires sufficient data to be available at each site in order for the locally trained models to achieve a minimum quality: even a single bad model can render aggregation arbitrarily bad (Shamir and Srebro, 2014). In many relevant applications this requirement is not met: in healthcare settings, often only a few dozen samples are available (Granlund et al., 2020; Su et al., 2021; Painter et al., 2020). Even in domains where deep learning is generally regarded as highly successful, such as natural language processing and object detection, applications often suffer from a lack of data (Liu et al., 2020; Kang et al., 2019). To tackle this problem, we propose a new building block for federated learning called daisy-chaining, in which models are trained on one local dataset after another, much like a daisy chain. In a nutshell, at each client a model is trained locally, sent to the server, and then, instead of being aggregated with the other local models, sent to a random other client as is (see Fig. 1). This way, each local model is exposed to a daisy chain of clients and their local datasets. This allows us to learn from small, distributed datasets simply by consecutively training the model with the data available at each site. Daisy-chaining alone, however, violates privacy, since a client can infer information about the data of the client it received the model from (Shokri et al., 2017). Moreover, performing daisy-chaining naively would lead to overfitting, which can cause learning to diverge (Haddadpour and Mahdavi, 2019).
In this paper, we propose to combine daisy-chaining of local datasets with aggregation of models, both orchestrated by the server, and term this method federated daisy-chaining (FEDDC). We show that our simple yet effective approach maintains privacy of local datasets, while it provably converges and guarantees improvement of model quality in convex problems with a suitable aggregation method. Formally, we show convergence for FEDDC on non-convex problems. We then show for convex problems that FEDDC succeeds on small datasets where standard federated learning fails. For that, we analyze FEDDC combined with aggregation via the Radon point from a PAC-learning perspective. We substantiate this theoretical analysis for convex problems by showing that FEDDC in practice matches the accuracy of a model trained on the full data of the SUSY binary classification dataset with only 2 samples per client, outperforming standard federated learning by a wide margin. For non-convex settings, we provide an extensive empirical evaluation, showing that FEDDC outperforms naive daisy-chaining, vanilla federated learning FEDAVG (McMahan et al., 2017), FEDPROX (Li et al., 2020a), FEDADAGRAD, FEDADAM, and FEDYOGI (Reddi et al., 2020) on low-sample CIFAR10 (Krizhevsky, 2009), including non-iid settings, and, more importantly, on two real-world medical imaging datasets. Not only does FEDDC provide a wide margin of improvement over existing federated methods, but it also comes close to the performance of a gold-standard (centralized) neural network of the same architecture trained on the pooled data. To achieve that, it requires a small communication overhead compared to standard federated learning for the additional daisy-chaining rounds. We consider a cross-silo scenario, as often found in healthcare, where such a small communication overhead is negligible. Moreover, we show that with equal communication, standard federated averaging still underperforms in our considered settings.
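The interleaving of local training, daisy-chaining, and aggregation described above can be illustrated with a minimal sketch. Everything in it is an assumption made for illustration rather than the method as specified in this paper: the one-dimensional least-squares task, the noise-free toy data with two samples per client, plain averaging as the aggregation method, the period names `d` (daisy-chaining) and `b` (aggregation), and the rule that aggregation takes precedence when both periods coincide. The sketch shows only the control flow; because the toy data is noise-free, it does not demonstrate the benefit over standard federated learning, and it includes no privacy mechanism.

```python
import random

TRUE_W = 3.0  # hypothetical ground-truth slope of the toy task y = TRUE_W * x


def make_clients(m=10):
    """Build m clients, each holding two noise-free samples (x, TRUE_W * x)."""
    clients = []
    for i in range(1, m + 1):
        xs = (0.1 * i, -0.05 * i)
        clients.append([(x, TRUE_W * x) for x in xs])
    return clients


def local_step(w, data, lr=0.1):
    """One SGD pass over a client's local data for the loss (w * x - y) ** 2."""
    for x, y in data:
        w -= lr * 2.0 * (w * x - y) * x
    return w


def feddc(client_data, rounds=200, d=1, b=10, seed=0):
    """Toy loop: train locally each round, permute models every d rounds,
    and average them every b rounds (averaging wins if both apply)."""
    rng = random.Random(seed)
    m = len(client_data)
    models = [0.0] * m  # models[j] is the model currently held by client j
    for t in range(1, rounds + 1):
        # Each client trains the model it currently holds on its own data.
        models = [local_step(w, data) for w, data in zip(models, client_data)]
        if t % b == 0:
            # Aggregation: the server averages all models and broadcasts the result.
            avg = sum(models) / m
            models = [avg] * m
        elif t % d == 0:
            # Daisy-chaining: the server sends each model, unchanged,
            # to a client chosen by a random permutation.
            perm = list(range(m))
            rng.shuffle(perm)
            models = [models[i] for i in perm]
    return models


if __name__ == "__main__":
    final_models = feddc(make_clients())
    print(final_models[0])
```

With `rounds` divisible by `b`, the final round ends in an aggregation step, so all clients hold the same model when the loop terminates. Setting `d` larger than `rounds` disables the permutation step and reduces the loop to periodic averaging, which gives a simple baseline for comparison within this toy setting.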
In summary, our contributions are (i) FEDDC, a novel approach to federated learning from small datasets via a combination of model permutations across clients and aggregation, (ii) a formal proof of convergence for FEDDC, (iii) a theoretical guarantee that FEDDC improves models in terms of (ε, δ)-guarantees, which standard federated learning cannot provide, (iv) a discussion of the privacy aspects and mitigations suitable for FEDDC, including an empirical evaluation of differentially private FEDDC, and (v) an extensive set of experiments showing that FEDDC substantially improves model quality on small datasets compared to standard federated learning approaches.

2. RELATED WORK

Learning from small datasets is a well-studied problem in machine learning. In the literature, we find both general solutions, such as using simpler models and transfer learning (Torrey and Shavlik, 2010), and more specialized ones, such as data augmentation (Ibrahim et al., 2021) and few-shot learning (Vinyals et al., 2016; Prabhu et al., 2019). In our scenario, overall data is abundant, but it is distributed into small local datasets at each site, which we are not allowed to pool. Hao et al. (2021) propose local data augmentation for federated learning, but their method requires a sufficient quality of the local model for augmentation, which is the opposite of the scenario we are considering. Huang et al. (2021) provide generalization bounds for federated averaging via the NTK framework, but their analysis requires one-layer infinite-width NNs and infinitesimal learning rates. Federated learning and its variants have been shown to learn from incomplete local data sources, e.g., non-iid label distributions (Li et al., 2020a; Wang et al., 2019) and differing feature distributions (Li et al., 2020b; Reisizadeh et al., 2020a), but fail in case of large gradient diversity (Haddadpour and Mahdavi, 2019) and strongly dissimilar label distributions (Marfoq et al., 2021). For small



Figure 1: Federated learning settings. A standard federated learning setting with training of local models at clients (middle) with aggregation phases where models are communicated to the server, aggregated, and sent back to each client (left). We propose to add daisy-chaining (right), where local models are sent to the server and then redistributed to a random permutation of clients as is.

