MERGING MODELS PRE-TRAINED ON DIFFERENT FEATURES WITH CONSENSUS GRAPH

Anonymous authors
Paper under double-blind review

Abstract

Learning an effective global model on private and decentralized datasets has become an increasingly important challenge of machine learning when applied in practice. Federated Learning (FL) has recently emerged as a solution to address this challenge. In particular, the FL clients agree to a common model parameterization in advance, which can then be trained collaboratively via synchronous aggregation of local model updates. However, such a strong requirement of modeling homogeneity and synchronicity across clients makes FL inapplicable to many practical scenarios. For example, in distributed sensing, a network of heterogeneous sensors samples different data modalities of the same phenomenon. Each sensor thus requires its own specialized model. Local learning therefore happens in isolation, but inference still requires merging the local models to achieve consensus. To enable isolated local learning and consensus inference, we investigate a feature fusion approach that extracts local feature representations from local models and incorporates them into a global representation for holistic prediction. We study two key aspects of this feature fusion. First, we use alignment to match feature components that are arbitrarily arranged across clients. Next, we learn a consensus graph that captures the high-order interactions among data sources or modalities, which reveals how data with heterogeneous features can be stitched together coherently to achieve a better prediction. The proposed framework is demonstrated on four real-life data sets including power grids and traffic networks.

1. INTRODUCTION

To improve the scalability and practicality of machine learning applications in situations where training data are becoming increasingly decentralized and proprietary, Federated Learning (FL) (McMahan et al., 2017; Yang et al., 2019a; Li et al., 2019; Kairouz et al., 2019) has been proposed as a new model training paradigm that allows data owners to collaboratively train a common model without having to share their private data with others. The FL formalism is therefore poised to resolve the computation bottleneck of model training on a single machine and the risk of privacy violation, in light of recent policies such as the General Data Protection Regulation (Albrecht, 2016) . However, FL requires a strong form of homogeneity and synchronicity among the data owners (clients) that might not be ideal in practice. First, it requires all clients to agree in advance to a common model architecture and parameterization. Second, it requires clients to synchronously communicate their model updates to a common server, which assembles the local updates into a global learning feedback. This is rather restrictive in cases where different clients draw observations from a different data modality of the phenomenon being modeled. It leads to heterogeneous data complexities across clients, which in turn requires customized forms of modeling. Otherwise, enforcing a common model with high complexity might not be affordable to clients with low compute capacity; and vice versa, switching to a model with low complexity might result in the failure to unlock important inferential insights from data modalities. A variant of FL (Hardy et al., 2017; Hu et al., 2019; Chen et al., 2020) , named vertical FL, has been proposed to address the first challenge, which embraces the concept of vertically partitioned data. This concept is figuratively named through cutting the data matrix vertically along the feature axis, rather than the data axis. 
Existing approaches generally maintain separate local model parameters distributed across clients and global parameters on a central server. All parameters are then learned jointly, which has a practical drawback: it requires coordination among the clients and the central server, such as engineering protocols that enable multiple rounds of communication (i.e., synchronicity) and coordination effort (i.e., homogeneity) to converge on universal choices of models and training algorithms. This overhead can be expensive depending on the scale of the application. To mitigate both constraints on homogeneity and synchronicity satisfactorily, we ask the following question and subsequently develop an answer to it: Can we separate global consensus prediction from local model training? As shown later in our experiments, we will address this question in a real-world context of the national electricity grid, over which thousands of phasor measurement units (PMUs) were deployed to monitor the grid condition and data were recorded in real-time by each PMU (Smartgrid.gov). PMU measurements, as time series data, are owned by several parties, each of which may employ different technologies leading to heterogeneous recordings under varying sampling frequencies and measured attributes. These data may be used to train machine learning models that identify grid events (e.g., fault, oscillation, and generator trip). Such an event detection system relies on collective series measurements at the same time window but distributed across owners. Using VFL to build a common model on such decentralized and heterogeneous data is plausible but not practical, because of a lack of autonomy that facilitates coordination among the owners. To resolve the challenge, we instead introduce a feature fusion perspective to this setting, which aims to minimize coordination among clients and maximize their autonomy via a local-global model framework.
Therein, each client trains a customized local model with its data modalities. The training is independent and incurs no coordination. Once trained, local feature representations of each client can then be extracted from the penultimate layer of the corresponding local models. Then, a central server collects and aggregates these representations into a more holistic global representation, used to train a model for global inference. There are two technical challenges that need to be addressed to substantiate the envisioned framework. C1. There is an ambiguity regarding the correspondence between components of local feature representations across different clients. This ambiguity arises because local models were trained separately in isolation and there is no mechanism to enforce that their induced feature dimensions would be aligned. As a matter of fact, it is possible to permute the induced feature dimensions without changing the prediction outcome. Thus, if two models are trained separately, they might end up looking at the same feature space but with permuted dimensions. C2. There are innate local interactions among subsets of clients that need to be accounted for. Naively concatenating or averaging the local feature representations accounts for the global interaction but ignores such local interactions, which are important to boost the accuracy of global prediction as shown later in our experiments. To address C1, note that the feature dimension alignment problem is discrete in nature; furthermore, there is no direct feedback to optimize for such alignment. To sidestep this challenge, we develop a neuralized alignment layer whose parameters are differentiable and can therefore be part of a larger network, including the feature aggregation and prediction layers, which can be trained end-to-end via gradient back-propagation (Section 4). 
To address C2, we employ graph neural networks as the global inference model, where the graph corresponds to the explicit or implicit relational structure of the data owners. As such a graph might not be given in advance, we treat the combinatorial graph structure as a random variable following a product of Bernoulli distributions whose (differentiable) parameters can also be optimized via a gradient-based approach (Section 5). The technical contributions of this work are summarized below.

1. We formalize a feature fusion perspective for distributed learning, in settings where data is vertically partitioned. This is an alternative view to VFL but, as elaborated above, is more applicable when iterative training synchronicity is not possible among clients (Section 2).

2. We formulate a federated feature fusion (F$^3$) framework that consists of a network of pre-trained local models and a central model that collects and fuses the local feature representations (induced from these pre-trained models) to generate a global model with better predictive performance (Section 3). This is achieved via addressing C1 (Section 4) and C2 (Section 5) above.

3. We demonstrate experiments with four real-life data sets, including power grids and traffic networks, and show the effectiveness of the proposed framework (Section 6).

2. PROBLEM SETTING AND RELATED WORK

Federated Feature Fusion (F$^3$) is a new and more practical setup for VFL (Hu et al., 2019; Chen et al., 2020); it aims to enable collaboration between data owners that possess private access to different sets of features describing the same set of training data points. However, unlike VFL, which requires clients to synchronize their training processes (Yang et al., 2019b; Li et al., 2021; Fu et al., 2021; Cheng et al., 2021; Hu et al., 2019; Diao et al., 2021) over multiple iterations of communication, F$^3$ allows data owners to train their own local models in isolation and only requires one round of communication in which local feature representations induced from the heterogeneously pre-trained local models are shared with a trusted server for feature fusion.

Remark. Similar ideas on extending federated learning to accommodate clients with heterogeneous models (Tan et al., 2022b;a; Lin et al., 2020; Chen et al., 2022) have been proposed previously but are still restricted to horizontal settings: local models still need to operate on the same feature space and cannot be trained in isolation, which consequently requires multiple rounds of communication and potentially incurs extra coordination overhead. Thus, to emphasize the novelty of our setting and the significance of our solution, we further review and discuss the formulations of VFL and F$^3$ below, arguing with concrete, real-life examples why the F$^3$ setting is more practical and how this practicality entails significant technical challenges that necessitate the new solutions in Sections 4 and 5.

Federated Learning with Vertically Partitioned Data. From a data perspective, the decentralized nature of data in VFL is a transposition of that in traditional horizontal federated learning (HFL) (McMahan et al., 2017).
Instead of owning the same set of features for different sets of data points as in HFL, the data owners in VFL now own different sets of features for the same set of data points, and they share a common label set for these data points. Two lines of work can be noted in the existing literature. One takes the data matrix literally, by assuming tabular data and studying linear models, where model parameters have a natural correspondence to the data parts (Hardy et al., 2017; Nock et al., 2018; Heinze et al., 2014; 2016). Often, these approaches are hard to generalize to complex data with many owners. Another line of work advocates the use of models with a modular structure in which separate parts of the model are responsible for locally aggregating different sets of local features owned by different owners, and a global parameterization is used to combine these local features. This is similar in spirit to F$^3$ but requires clients to synchronize the training processes of their assigned model parts, which incurs expensive communication and creates dependence among the clients (Hu et al., 2019; Chen et al., 2020).

Mathematically, for each datum $x_k$ with label $y_k$, let $x_k^i$ be the feature set of the datum that the $i$-th owner possesses. That is, $x_k = (x_k^1, x_k^2, \ldots, x_k^n)$ with $n$ data owners. VFL computes:

$$\underset{w,\theta}{\text{minimize}} \;\; L(w, \theta) \triangleq \frac{1}{m} \sum_{k=1}^{m} \ell\Big( g\big( \phi_1(x_k^1; \theta_1), \phi_2(x_k^2; \theta_2), \ldots, \phi_n(x_k^n; \theta_n); w \big),\, y_k \Big) \tag{1}$$

where each $\phi_i(x_k^i; \theta_i)$ is a (learnable) local embedding of $x_k^i$ parameterized by a separate parameter vector $\theta_i$ owned by the $i$-th owner, $g(\phi_1, \phi_2, \ldots, \phi_n; w)$ is an aggregation function parameterized by $w$, and $\ell$ is a prediction loss, e.g. the cross-entropy loss for classification or the $\ell_2$ loss for regression. The loss in Eq. (1) is averaged over all training data points $x_1, x_2, \ldots, x_m$.

Federated Feature Fusion.
The setting of F$^3$ is similar to VFL, except that the data owners share neither data nor models with each other to ensure a higher degree of privacy compliance, which is often the more practical setting in industry (see the power grid example at the end of this section). For this reason, the VFL minimization task in Eq. (1) above is changed to

$$\underset{w}{\text{minimize}} \;\; L(w) \triangleq \frac{1}{m} \sum_{k=1}^{m} \ell\Big( g\big( h_k^1, h_k^2, \ldots, h_k^n; w \big),\, y_k \Big) \tag{2}$$

where $h_k^i = \phi_i^*(x_k^i)$ with $\phi_i^* = \arg\min_{\phi_i} \ell_i\big(\phi_i(x_k^i), y_k\big)$, which characterizes the locally optimal feature representation obtained in isolation by the $i$-th owner. As such, Eq. (2) only requires one round of communication where $\{h_k^i\}_{k,i}$ are communicated to a trusted server. Prior to that, each data owner can freely learn their own feature representation model $\phi_i(x_k^i)$ with a different parameterization and architecture, catering to their own compute capacities and data representation. This avoids forcing the data owners to participate in a joint training scheme, which often requires expensive coordination and is not practical.

However, in exchange for this practicality, two key challenges arise. First, as local models are separately trained, the correspondence between components of induced feature representations across local models becomes ambiguous since there is no mechanism to enforce their alignment. Second, for the same reason, there are potential innate local interactions among subsets of clients, and a naive concatenation or averaging of their corresponding feature representations will likely ignore such interactions, resulting in degraded performance. These correspond to the high-level challenges C1 and C2 in Section 1, which will be addressed in Sections 4 and 5 as our key technical contributions.

Data Example. Let us consider the power grid monitoring task as an example. Figure 1 pictorially illustrates PMU measurements distributed across data owners.
A panel of time series corresponds to a specific time window and the series collectively represent one data point, which the event detection system classifies. In this simplified illustration, each data owner possesses one series recorded by one PMU; but in practice they may own different amounts of PMUs (and thus series). Moreover, the series may differ in length because of varying sampling frequencies; and the series are multivariate with possibly different number of variates. All these variations contribute to data heterogeneity, which necessitates the construction of separate local models. Note that if an event does not cascade over the entire grid, some local models may report event whereas others report normal, resulting in conflicting opinions. A consensus global model is responsible to resolve the conflict. Additionally, missing data may occur.

3. FEDERATED FEATURE FUSION FRAMEWORK

As detailed above, the proposed framework for Federated Feature Fusion consists of local models $\phi_i$ and a global feature fusion model $g$, such that their composition minimizes the loss in Eq. (2). Each data owner $i$ possesses a local model trained with its data, independently of other owners. This way, no data sharing is invoked and privacy is of minimal concern. However, because the local models lack a global vision and may be conflicting, a central (global) model is key to coordinating the local opinions for final prediction. To maintain autonomy, local models are frozen once pre-trained and will not join the training of the global model. Data owners send local data representations to a centralized server for global model training (and inference). In other words, the global model queries neither the raw data nor the local models from data owners. As long as owners agree to send the less decipherable representations to the central server, global inference can be made.

Local Models. We treat a neural network except the final output layer as a feature extractor, which produces the representation $h_k^i$ of an input fragment $x_k^i$; and treat for simplicity the output layer as

$$g_i(x_k^i) \triangleq \mathrm{softmax}\big(W_i \cdot h_k^i + b_i\big) \quad \text{where} \quad h_k^i = \phi_i(x_k^i). \tag{3}$$

Hereafter, we will interchangeably use representation, embedding, and latent vector to mean $h_k^i$. These $h_k^i$'s are assumed to have the same shape across $i$, although $x_k^i$ can have different shapes and the embedding function can have different architectures to cope with data heterogeneity. A simple example of the embedding function is a fully connected layer $h_k^i = \mathrm{ReLU}(U_i \cdot x_k^i + c_i)$; but an arbitrarily complex function is also applicable.

Global Model. The global model $g$ melds together all local representations to generate a prediction:

$$y_k \simeq \hat{y}_k \triangleq g\big(h_k^1, h_k^2, \ldots, h_k^n; w\big), \tag{4}$$

which is parameterized by $w$.
For example, the parameterization $w = \{W_0, W_1, b_0, b_1\}$ can characterize a fully connected layer followed by mean pooling and another fully connected layer:

$$g\big(h_k^1, h_k^2, \ldots, h_k^n; w\big) = \mathrm{softmax}\Big( W_1 \cdot \Big( \frac{1}{n} \sum_{i=1}^{n} \mathrm{ReLU}\big(W_0 \cdot h_k^i + b_0\big) \Big) + b_1 \Big). \tag{5}$$

Thus, given a particular parameterization $w$, we can substantiate Eq. (4) above and plug it into Eq. (2). The optimal value for $w$ can then be obtained by solving the corresponding minimization task therein. However, designing the form of $w$ is highly non-trivial and is in fact tied to the previously mentioned challenges C1 and C2, which we further elaborate below.

Challenges. Two considerations are pertinent to the design of $w$. First, when the latent dimensions have semantic meaning, e.g. when the local models are trained to yield disentangled representations (Higgins et al., 2018), the latent features of different local representations may not match, because an arbitrary permutation of the latent dimensions does not change a local model. Second, a naive mean pooling as in (5) may miss the interdependencies between local data, leading to a weaker global model. Such interdependencies naturally occur in the power grid example because of the physics of an electricity network. Hence, in subsequent sections, we use latent alignment to address the first problem and a graph neural network to address the second one. Incorporating these two components, we show the full proposed framework in Figure 2 and Algorithm 1. We now discuss the solutions to these challenges in Sections 4 and 5 below.
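To make the graph-agnostic global model of Eq. (5) concrete, the following is a minimal sketch in plain Python. The function name `mean_pool_global_model`, the helper functions, and the toy weights are illustrative assumptions and are not taken from the paper; the code only mirrors the formula.

```python
import math

def matvec(W, x):
    """Matrix-vector product for a matrix stored as a list of rows."""
    return [sum(w * xi for w, xi in zip(row, x)) for row in W]

def relu(v):
    return [max(0.0, a) for a in v]

def softmax(v):
    m = max(v)  # subtract the max for numerical stability
    e = [math.exp(a - m) for a in v]
    s = sum(e)
    return [a / s for a in e]

def mean_pool_global_model(H, W0, b0, W1, b1):
    """Eq. (5): softmax(W1 . mean_i ReLU(W0 . h_i + b0) + b1).

    H is a list of n local representations (one per data owner).
    """
    n = len(H)
    hidden = [relu([a + b for a, b in zip(matvec(W0, h), b0)]) for h in H]
    pooled = [sum(col) / n for col in zip(*hidden)]  # mean over owners
    logits = [a + b for a, b in zip(matvec(W1, pooled), b1)]
    return softmax(logits)

# Toy example (hypothetical numbers): 3 owners, 2-dim representations, 2 classes.
H = [[1.0, -2.0], [0.5, 0.5], [-1.0, 3.0]]
W0 = [[1.0, 0.0], [0.0, 1.0], [1.0, 1.0]]
b0 = [0.0, 0.0, 0.0]
W1 = [[1.0, 0.0, 0.0], [0.0, 1.0, 0.0]]
b1 = [0.0, 0.0]

probs = mean_pool_global_model(H, W0, b0, W1, b1)
probs_reordered = mean_pool_global_model(H[::-1], W0, b0, W1, b1)
```

In this sketch the prediction is unchanged when the owners are reordered, which reflects that mean pooling treats all owners symmetrically and therefore cannot express interactions among specific subsets of owners (challenge C2).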

4. ALIGNING LOCAL REPRESENTATIONS

For the global model to be meaningful, the feature dimensions of the local representations $h_k^i$ should be aligned under the same feature space. For example, in (5), all $h_k^i$'s multiply the same weight matrix $W_0$; in other words, each element of $h_k^i$ corresponds to one input neuron of the initial fully connected layer. Permutations of the elements will destroy the correspondence. That is, even if the local models are fixed, the arbitrary arrangement of the feature dimensions of the latent vectors causes ambiguity about what an optimal global model can be built.

Algorithm 1 Federated Feature Fusion (F$^3$)

    function TRAINING($\{\{(x_k^i, y_k)\}_{k=1}^{m}\}_{i=1}^{n}$)
        Each data owner $i$ trains a local model $g_i$ with its local data part $\{(x_k^i, y_k)\}_{k=1}^{m}$.
        Each data owner $i$ sends its local data representations $\{h_k^i\}_{k=1}^{m}$ to the central server (Eq. (3)).
        Central server learns $\hat{y}_k = g(P_1 h_k^1, P_2 h_k^2, \ldots, P_n h_k^n)$ via Eq. (7).
            Here, the global model is (8), where the loss is taken over the distribution of $A$.
            Entries of $A$ are sampled using (9).
            Each alignment matrix $P_i$ is a learnable arbitrary parameter matrix.
    end function

    function INFERENCE($x_1, \ldots, x_m$ where $x_k = (x_k^1, \ldots, x_k^n)$)
        Each data owner $i$ evaluates its local model with $x_k^i$ to obtain $h_k^i$ and sends it to the server.
        Server takes $\{h_k^i\}_{i=1}^{n}$ as input and produces a prediction via Eq. (5).
    end function

Mathematically, let us use a vector $p$ to denote the index (column) permutation of a vector (matrix). Then, the $i$-th local model (3) can be equivalently written as

$$g_i(x_k^i) \triangleq \mathrm{softmax}\big(W_i[:, p_i] \cdot h_k^i[p_i] + b_i\big) \quad \text{where} \quad h_k^i \triangleq \phi_i(x_k^i), \tag{6}$$

for any permutation $p_i$, as long as the embedding function is able to produce a permuted $h_k^i[p_i]$ under the same input $x_k^i$. Such a requirement can be easily satisfied if the embedding function is a fully connected layer such as $h[p] = \mathrm{ReLU}(W[p, :] \cdot x + b[p])$. In fact, it is satisfied by most neural networks as well. In Appendix B, we give another example: the GRU (Cho et al., 2014).

Hence, we propose to align the feature dimensions across all local vectors $h_k^i$ to resolve this ambiguity. This proposal amounts to modifying the global model (4) to the following:

$$y_k \simeq \hat{y}_k \triangleq g\big(P_1 \cdot h_k^1, P_2 \cdot h_k^2, \ldots, P_n \cdot h_k^n\big), \tag{7}$$

where $P_i$ is an alignment matrix for each data owner $i$, implementing the (manual) index or column permutation above in linear algebra. We can then treat each $P_i$ as a free parameter matrix to optimize. It may be square or rectangular, the latter case indicating a change in the number of features. We also show an alternative hard alignment by parametrizing $P_i$ as a permutation matrix in Appendix H.
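The permutation ambiguity and the role of an alignment matrix can be illustrated with a small numerical sketch. It assumes a single fully connected embedding layer with a linear output layer, random weights, and a hand-picked permutation `p`; all names are hypothetical, and the alignment matrix `P` is constructed by hand here rather than learned as in the proposed framework.

```python
import random

def matvec(W, x):
    """Matrix-vector product for a matrix stored as a list of rows."""
    return [sum(w * xi for w, xi in zip(row, x)) for row in W]

def relu(v):
    return [max(0.0, a) for a in v]

def local_model(U, c, W, b, x):
    """Embedding h = ReLU(U x + c), followed by output logits W h + b."""
    h = relu([a + ci for a, ci in zip(matvec(U, x), c)])
    logits = [a + bi for a, bi in zip(matvec(W, h), b)]
    return h, logits

rng = random.Random(0)
d_in, d_h, d_out = 4, 5, 3
U = [[rng.gauss(0, 1) for _ in range(d_in)] for _ in range(d_h)]
c = [rng.gauss(0, 1) for _ in range(d_h)]
W = [[rng.gauss(0, 1) for _ in range(d_h)] for _ in range(d_out)]
b = [rng.gauss(0, 1) for _ in range(d_out)]
x = [rng.gauss(0, 1) for _ in range(d_in)]

# Permute the hidden units: rows of U, entries of c, and columns of W.
# The output bias b is indexed by class and is therefore left untouched.
p = [3, 0, 4, 1, 2]
U_p = [U[i] for i in p]
c_p = [c[i] for i in p]
W_p = [[row[i] for i in p] for row in W]

h, logits = local_model(U, c, W, b, x)
h_p, logits_p = local_model(U_p, c_p, W_p, b, x)  # same logits, permuted h

# A permutation matrix P that maps the permuted representation back: P h_p = h.
P = [[0.0] * d_h for _ in range(d_h)]
for new_idx, old_idx in enumerate(p):
    P[old_idx][new_idx] = 1.0
h_aligned = matvec(P, h_p)
```

The two local models produce the same logits while exposing differently ordered representations, so a server that only sees `h_p` cannot tell which ordering was used. A learned $P_i$ plays the role of `P` in this sketch, with the difference that it is optimized jointly with the global model rather than known in advance.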

5. LEARNING A CONSENSUS GRAPH

The example global model (5) performs a naive averaging for the local representations. Since data owners are often interconnected, a more expressive model exploits their relational interactions to improve inference (Battaglia et al., 2018). To this end, we propose to use a graph neural network (GNN) (Zhang et al., 2020; Wu et al., 2021) to process the latent representations.

A. Modeling Consensus Graph via GCN with Latent Graph. Many GNNs are applicable; we focus on GCN (Kipf & Welling, 2017) for its simplicity. Let $A$ be the graph adjacency matrix and let $H_k$ be the matrix of aligned local representations:

$$H_k \triangleq \begin{bmatrix} (P_1 h_k^1)^\top \\ \vdots \\ (P_n h_k^n)^\top \end{bmatrix}.$$

Traditionally, GCN was designed for node classification so we modify it slightly for our purpose,

$$y_k \simeq \hat{y}_k \triangleq \mathrm{softmax}\Big( \frac{1}{n} \mathbf{1}^\top \hat{A} \cdot \mathrm{ReLU}\big(\hat{A} H_k W_0\big) \cdot W_1 \Big), \tag{8}$$

where $\hat{A}$ is a normalization of $A$ (see Kipf & Welling (2017) for details) and $W_0$ and $W_1$ are weight matrices. The modification is the inclusion of $\frac{1}{n} \mathbf{1}^\top$ as pooling before output. Modulo this modification, the formula (8) is a standard one used in the literature, with the bias terms omitted. It is interesting to note the equivalence of GCN (8) and the graph-agnostic model (5) when $\hat{A}$ is replaced by the identity matrix (omitting bias terms).

In GCN, $A$ corresponds to the consensus graph among local owners as graph nodes. If such a graph is not present, it is possible to learn one such that (8) still outperforms (5). In this case, we treat $A$ as a random variable of the matrix Bernoulli distribution, where the success probabilities are free parameters to learn. Formally, the elements $A_{ij}$ are independent and each follows $\mathrm{Ber}(\theta_{ij})$, where $\theta_{ij}$ denotes the corresponding probability (Kipf et al., 2018; Shang et al., 2021). Then, the global model $g$ has $W_0$, $W_1$, the $P_i$'s, as well as $\theta$, as parameters. Following Franceschi et al. (2019); Shang et al.
(2021), we formulate the training loss as an expectation over $A$'s distribution and draw a sample $A$ to obtain an unbiased estimate of the loss as well as the gradient.

B. Differentiable Graph Sampling via Re-parameterization. However, the central challenge of this approach is that the sample $A_{ij}$ is not differentiable with respect to the corresponding Bernoulli bias $\theta_{ij}$, which in turn makes the training loss non-differentiable with respect to $\theta$. To sidestep this difficulty, we propose the following re-parameterization, which presents a learnable (differentiable) transformation of a sample drawn from a continuous distribution to a discrete Bernoulli sample. This transformation is detailed in Definition 1 below, which is followed by Theorem 1 showing the distributional convergence of this transformation to the desired Bernoulli distribution.

Definition 1. Let $F$ be the CDF of an arbitrary continuous probability distribution. Sample $s$ from this reference distribution and let

$$z \triangleq \mathrm{sigmoid}\Big( \frac{1}{\tau} \big( F^{-1}(\theta) - s \big) \Big), \quad \tau > 0. \tag{9}$$

We call this the ICDF re-parameterization, named after the use of the inverse cumulative distribution function $F^{-1}$.

Theorem 1. For all $\tau > 0$, $\theta \in (0, 1)$ and $t \in [0, 1]$, if the distribution with CDF $F$ is finitely supported on $[a, b]$, then

$$\Pr(z \le t) = \begin{cases} 0 & \text{if } t < \mathrm{sigmoid}\big((F^{-1}(\theta) - b)/\tau\big), \\ 1 & \text{if } t > \mathrm{sigmoid}\big((F^{-1}(\theta) - a)/\tau\big), \\ 1 - F\big(F^{-1}(\theta) + \tau \log(t^{-1} - 1)\big) & \text{otherwise.} \end{cases} \tag{10}$$

On the other hand, if the distribution is not finitely supported (i.e., $a = -\infty$ and/or $b = +\infty$), Eq. (10) still holds because either (or both) of the first two cases will not be invoked. As a consequence, the distribution of $z$ converges to $\mathrm{Ber}(\theta)$ as $\tau \to 0$.

Discussion. We note that an alternative to the above is the Gumbel softmax re-parameterization (Jang et al., 2017; Maddison et al., 2017), which also features a differentiable relaxation of the categorical distribution (in this case, the Bernoulli distribution) that approximates it asymptotically.
However, in order to obtain one Bernoulli sample, the Gumbel trick requires sampling the Gumbel distribution twice. In contrast, our proposed re-parameterization requires sampling from the reference distribution only once. We also show that the ICDF re-parameterization converges as fast as the Gumbel softmax: both approaches have an asymptotic convergence rate on the order of $O(\tau^2)$, as shown in Section D. Empirically, we also show that ICDF induces marginally better performance than Gumbel softmax. This is why we prefer ICDF to Gumbel in our work.
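The pieces of this section can be put together in a short sketch: a relaxed adjacency matrix is drawn with the ICDF re-parameterization of Definition 1 and then used in the pooled GCN of Eq. (8). Several choices below are assumptions made only for illustration: the reference distribution is Uniform(0, 1), for which $F^{-1}(\theta) = \theta$; the normalization is a plain row normalization with self-loops, used as a stand-in for the normalization of Kipf & Welling (2017); and all weights are random. Gradients are not computed here, so the sketch shows only the forward pass.

```python
import math
import random

def sigmoid(u):
    """Numerically stable logistic function."""
    if u >= 0:
        return 1.0 / (1.0 + math.exp(-u))
    e = math.exp(u)
    return e / (1.0 + e)

def icdf_sample(theta, tau, rng):
    """Definition 1 with a Uniform(0, 1) reference, so F^{-1}(theta) = theta."""
    s = rng.random()
    return sigmoid((theta - s) / tau)

def matmul(A, B):
    return [[sum(a * b for a, b in zip(row, col)) for col in zip(*B)] for row in A]

def softmax(v):
    m = max(v)
    e = [math.exp(a - m) for a in v]
    total = sum(e)
    return [a / total for a in e]

def row_normalize(A):
    """Assumed stand-in normalization: add self-loops, then divide each row by its sum."""
    n = len(A)
    A_loop = [[A[i][j] + (1.0 if i == j else 0.0) for j in range(n)] for i in range(n)]
    return [[a / sum(row) for a in row] for row in A_loop]

def gcn_forward(A_hat, H, W0, W1):
    """Eq. (8): softmax((1/n) 1^T A_hat ReLU(A_hat H W0) W1), bias terms omitted."""
    n = len(H)
    hidden = [[max(0.0, v) for v in row] for row in matmul(matmul(A_hat, H), W0)]
    mixed = matmul(A_hat, hidden)
    pooled = [sum(col) / n for col in zip(*mixed)]  # the (1/n) 1^T pooling
    return softmax(matmul([pooled], W1)[0])

rng = random.Random(0)
n, d, d_hid, n_cls = 4, 3, 5, 2
H = [[rng.gauss(0, 1) for _ in range(d)] for _ in range(n)]        # aligned representations
W0 = [[rng.gauss(0, 1) for _ in range(d_hid)] for _ in range(d)]
W1 = [[rng.gauss(0, 1) for _ in range(n_cls)] for _ in range(d_hid)]
theta = [[0.5] * n for _ in range(n)]                               # edge probabilities

A = [[icdf_sample(theta[i][j], 0.1, rng) for j in range(n)] for i in range(n)]
probs = gcn_forward(row_normalize(A), H, W0, W1)

# With the identity in place of A_hat, Eq. (8) reduces to the mean-pooling model.
identity = [[1.0 if i == j else 0.0 for j in range(n)] for i in range(n)]
probs_identity = gcn_forward(identity, H, W0, W1)
```

Each entry of `A` requires a single draw from the reference distribution, which is the economy over the Gumbel trick noted in the discussion above. In a differentiable framework, `theta` would be a trainable parameter and the sample path from `theta` to `probs` would be smooth for any $\tau > 0$.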

6. EXPERIMENTS

In this section, we present comprehensive experiments to show that federated feature fusion (F$^3$) can be effectively conducted by using the proposed techniques in Sections 4 and 5.

Datasets. We use four real-life, time series datasets. Two are PMU data collected from multiple data owners of the U.S. power grid. For proof of concept, we smooth out heterogeneity and prepare homogeneous data sets. Such pre-processing is sufficient to test the proposed techniques while minimizing the complications that diverse local models would otherwise introduce. Since the PMU data sets are proprietary, we also use two public traffic data sets (Li et al., 2018) for experimentation. A summary of these data sets is given in Table 1 and the processing details are given in the supplement.

From Table 2, we observe that baselines (A-D), which lack either the necessary localized models or a holistic global model, perform significantly worse than the other baselines (including our F$^3$ variants, the ensemble via mean pooling baseline, and vertical federated learning). On the other hand, baselines (E-H) perform better than (A-D), but because they either lack a proper alignment of local models or impose a strong form of homogenization among local models to sidestep alignment, they are, as expected, outperformed by baseline (K), which performs alignment. We also compare with two variants of our final model K in the vertical federated learning setting. L uses the same local and global models as K but allows gradients to be sent back to local clients, so that local models can be updated. It achieves similar performance to K but incurs a much higher communication cost due to multiple rounds of gradient messages. M assumes no local pre-trained models; all local and global models are trained jointly from scratch. Its performance is much worse than that of K and L, which illustrates the merit of pre-trained local models. Another typical VFL baseline (G), with pre-trained local models and a simple concatenation-based global model, is also inferior.
Impact of Learning Graph. Our next set of experiments, as outlined in Table 3, demonstrates the impact of learning a graph that characterizes the innate local interactions among subsets of clients, following our challenge statement C2 in the introduction, on both alignment and non-alignment baselines. This provides ablation studies on the isolated impact of having a specific graph learning component. In particular, for each alignment setting, we demonstrate the impact on performance of (a) not using a graph; (b) using a graph given by the domain experts; (c) learning the graph structure using a k-NN baseline (Fatemi et al., 2021); and (d) learning the graph structure using ICDF. The reported numbers suggest that regardless of whether the model performs alignment, graph learning always improves performance.

Remark. The k-NN baseline ($k = 10$) is implemented following the description in (Fatemi et al., 2021). Specifically, during training, we generate a local graph for each batch of node features $X$ via a symmetrization of $\tilde{A} = \text{k-NN}(\mathrm{MLP}(X))$, which (1) feeds the node features through an MLP neural block; and (2) draws an edge between each node and its $k$ nearest neighbors, where the neighborhood is defined using the cosine similarity on the space of MLP-projected feature vectors.
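The graph construction used by the k-NN baseline can be sketched as follows. This is a simplified illustration rather than the implementation of Fatemi et al. (2021): the MLP projection is omitted (the features are used directly), the symmetrization takes the element-wise maximum of $\tilde{A}$ and its transpose, which is one of several possible choices and is an assumption here, and the function names and toy features are hypothetical.

```python
import math

def cosine(u, v):
    """Cosine similarity between two nonzero vectors."""
    dot = sum(a * b for a, b in zip(u, v))
    norm_u = math.sqrt(sum(a * a for a in u))
    norm_v = math.sqrt(sum(b * b for b in v))
    return dot / (norm_u * norm_v)

def knn_graph(X, k):
    """Connect each node to its k most cosine-similar nodes, then symmetrize.

    X is a list of node feature vectors (assumed nonzero). In the baseline
    described above, X would first be passed through an MLP; that step is
    skipped in this sketch.
    """
    n = len(X)
    A = [[0.0] * n for _ in range(n)]
    for i in range(n):
        sims = sorted(
            ((cosine(X[i], X[j]), j) for j in range(n) if j != i),
            reverse=True,
        )
        for _, j in sims[:k]:
            A[i][j] = 1.0
    # Assumed symmetrization: keep an edge if either endpoint selected it.
    return [[max(A[i][j], A[j][i]) for j in range(n)] for i in range(n)]

# Toy features: nodes 0 and 1 point in similar directions, as do nodes 2 and 3.
X = [[1.0, 0.0], [0.9, 0.1], [0.0, 1.0], [0.1, 0.9]]
A_knn = knn_graph(X, k=1)
```

With these toy features and $k = 1$, the resulting graph links node 0 with node 1 and node 2 with node 3. Unlike the ICDF approach, this construction depends only on feature similarity within a batch and has no edge probabilities to learn.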

7. CONCLUSIONS

In this paper, we study federated feature fusion, which presents a less addressed scenario of federated learning where data owners or clients need to customize their own local models to accommodate different sets of (federated) features. Unlike federated learning, the clients need to learn their own model separately in isolation and only communicate their local feature representations afterwards. We motivate the practicality of federated feature fusion scheme with a power grid example and propose a local-global model framework for it. Two important components of the framework are the alignment of the data representations produced by local models and the learning of the global model by using a graph neural network. Comprehensive experiments suggest the feasibility of federated feature fusion and the effectiveness of the framework.

A RELATED WORK

The concept of federated learning was first coined by McMahan et al. (2017) and it has attracted surging interests since. A form of distributed optimization, federated learning is faced with data challenges beyond conventional assumptions and puts communication efficiency and data privacy as primary concerns. Recent surveys (Yang et al., 2019a; Li et al., 2019; Kairouz et al., 2019) comprehensively study the subject, review systems and infrastructures, and suggest open problems. The typical setting of federated learning is that data sets across owners share the same feature space but differ in samples. Besides this horizontal partitioning of the data matrix, a vertical partitioning was studied by Hardy et al. (2017) ; Nock et al. (2018) ; Heinze et al. (2014; 2016) , wherein features are split across owners instead. This setting bears resemblance to our federated feature fusion scenario, but a crucial distinction is that existing methods for vertical federated learning all perform joint training. In the referenced work, to preserve privacy, encrypted data or randomly projected data are communicated among data owners as well as a central coordinator. Such an approach incurs demanding communication for many owners. Recently, Chen et al. ( 2020) study a different model, whose parameters are distributed among owners as well as a central server. The part of the model corresponding to an owner bears resemblance to our local models; but they are not local models since they are not independently trained by using local data. Another work along a similar direction is conducted by Hu et al. (2019) , but the global model has no parameters; it is merely a sum of the local outputs followed by activation (e.g., sigmoid for classification). Our framework learns parameter matrices to align local representations. 
Such alignments similarly appear in model fusion, where a number of models are fused into a common model by aligning model parameters (Yurochkin et al., 2019a). In the context of deep learning, if the neural networks come from the same model family, their weights can be matched layerwise, even if the numbers of weights differ (Yurochkin et al., 2019b; Wang et al., 2020). The referenced work treats the problem as bipartite graph matching, where the cost matrix is inferred from maximum a posteriori estimation. Then, the Hungarian algorithm (Kuhn, 1955) is applied to find the matching. In our work, we instead treat the permutation alignment as a differentiable parameterization with the help of Sinkhorn-Knopp (Sinkhorn & Knopp, 1967; Mena et al., 2018; Emami & Ranka, 2018), so that it can be learned end-to-end with the other parameters of the global model.

Our framework also advocates learning a graph of data owners in the global model. Graph structure learning appears in various contexts. One field of study is probabilistic graphical models and causal inference, whereby a directed acyclic structure is learned. Gradient-based approaches in this context include Zheng et al. (2018); Yu et al. (2019); Lachapelle et al. (2020). On the other hand, a general graph may still be useful without resorting to causality. Recent approaches supporting GNN-based modeling include Kipf et al. (2018); Franceschi et al. (2019); Wu et al. (2020); Shang et al. (2021), wherein a graph structure is learned simultaneously with the GNN parameters. The Gumbel trick (Jang et al., 2017; Maddison et al., 2017) is frequently used for differentiable parameterization, but in this paper we study a more economical alternative parameterization, ICDF.

B PERMUTATION AMBIGUITY EXAMPLE FOR GRU

In Section 3, we discuss that one can arbitrarily permute the latent representations while keeping a local model fixed. Here, we give another example: the GRU. Let $x = \{x_1, x_2, \ldots, x_T\}$ be an input sequence. The embedding function $h = \mathrm{embedding}(x)$ implemented as a GRU reads:

function $h = \mathrm{GRU}(\{x_t\}_{t=1}^{T})$
    $h_0 = 0$
    for $t = 1, \ldots, T$ do
        $z_t = \mathrm{sigmoid}(W_z x_t + U_z h_{t-1} + b_z)$
        $r_t = \mathrm{sigmoid}(W_r x_t + U_r h_{t-1} + b_r)$
        $n_t = \tanh(W_n x_t + U_n (r_t \odot h_{t-1}) + b_n)$
        $h_t = (1 - z_t) \odot h_{t-1} + z_t \odot n_t$
    end for
    return $h = h_T$
end function
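The permutation ambiguity can be checked numerically. The following is a minimal sketch, assuming a plain-Python implementation of the recursion above with small random parameters; the helper names (`gru_embedding`, `permute_params`), the dimensions, and the index list `p` are illustrative and not part of the original work. It permutes the rows of the input weight matrices, the rows and columns of the hidden weight matrices, and the entries of the bias vectors, and then compares the resulting embedding with the permuted original embedding.

```python
import math
import random

def sigmoid(v):
    return [1.0 / (1.0 + math.exp(-a)) for a in v]

def matvec(M, v):
    return [sum(m_ij * v_j for m_ij, v_j in zip(row, v)) for row in M]

def add(*vecs):
    return [sum(items) for items in zip(*vecs)]

def gru_embedding(xs, params):
    """Return h_T for the GRU recursion, starting from h_0 = 0."""
    Wz, Uz, bz, Wr, Ur, br, Wn, Un, bn = params
    h = [0.0] * len(bz)
    for x in xs:
        z = sigmoid(add(matvec(Wz, x), matvec(Uz, h), bz))
        r = sigmoid(add(matvec(Wr, x), matvec(Ur, h), br))
        rh = [ri * hi for ri, hi in zip(r, h)]
        n = [math.tanh(a) for a in add(matvec(Wn, x), matvec(Un, rh), bn)]
        h = [(1 - zi) * hi + zi * ni for zi, hi, ni in zip(z, h, n)]
    return h

def permute_params(params, p):
    """Permute rows of each W, rows and columns of each U, and entries of each b."""
    out = []
    for k in range(0, 9, 3):
        W, U, b = params[k:k + 3]
        out += [[W[i] for i in p],
                [[U[i][j] for j in p] for i in p],
                [b[i] for i in p]]
    return out

# Illustrative sizes and random parameters (assumptions, not from the paper).
rng = random.Random(0)
d_in, d_h, T = 3, 4, 5

def rand_mat(rows, cols):
    return [[rng.uniform(-1, 1) for _ in range(cols)] for _ in range(rows)]

params = []
for _ in range(3):  # one (W, U, b) triple per gate: z, r, n
    params += [rand_mat(d_h, d_in), rand_mat(d_h, d_h),
               [rng.uniform(-1, 1) for _ in range(d_h)]]
xs = rand_mat(T, d_in)
p = [2, 0, 3, 1]

h = gru_embedding(xs, params)
h_perm = gru_embedding(xs, permute_params(params, p))
# Entry i of the permuted model's embedding should equal entry p[i] of the original.
max_gap = max(abs(h_perm[i] - h[p[i]]) for i in range(d_h))
print(max_gap)
```

The printed gap should be at the level of floating-point round-off, which illustrates that the same input-output behavior is realized by many parameter settings that differ only in the ordering of the latent dimensions.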

C.2 PROOF OF THEOREM 1

We first consider the case when the distribution with CDF $F$ is finitely supported on $[a, b]$. Through simple algebraic manipulation, we obtain that $z \le t$ is equivalent to $s \ge M$, where $M := F^{-1}(\theta) + \tau \log(t^{-1} - 1)$. If $t < \mathrm{sigmoid}((F^{-1}(\theta) - b)/\tau)$, we see that $M > b$ and thus such $s$ can never occur. Similarly, if $t > \mathrm{sigmoid}((F^{-1}(\theta) - a)/\tau)$, we see that $M < a$, which indicates that $s \ge M$ always happens. Otherwise, when $t$ is within the two extremes, the probability that $s \ge M$ happens is $1 - F(M)$, concluding the proof of (10). The statement of the theorem regarding the case when the distribution is not finitely supported follows from the same argument, without the two boundary cases.

To show that the distribution of $z$ converges to $\mathrm{Ber}(\theta)$, let us first consider the scenario when the distribution with CDF $F$ is finitely supported. The CDF of $z$ (see (10)) is always continuous, but it has three segments connected at two joints: $t_1 = \mathrm{sigmoid}((F^{-1}(\theta) - b)/\tau)$ and $t_2 = \mathrm{sigmoid}((F^{-1}(\theta) - a)/\tau)$. When $\tau \to 0$, the joint $t_1 \to 0$ and the joint $t_2 \to 1$, and thus the support of the middle segment widens and converges to $[0, 1]$. Hence, it suffices to consider only the middle segment. Further, with an analogous argument for the other scenarios, it also suffices to consider only the third case of (10). In this case, for any fixed $t \in (0, 1)$ and when $\tau \to 0$, we have $\tau \log(t^{-1} - 1) \to 0$ and thus $\Pr(z \le t) \to 1 - F(F^{-1}(\theta)) = 1 - \theta$. Meanwhile, we cannot push $t \to 1$, because then the limit of $\tau \log(t^{-1} - 1)$ is undefined. However, we know by definition that $\Pr(z \le 1) = 1$. Hence, the continuous distribution of $z$ converges to a degenerate distribution with $\Pr(z < 1) = 1 - \theta$ and $\Pr(z \le 1) = 1$. This is the CDF of $\mathrm{Ber}(\theta)$.
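The two claims of the proof can be sanity-checked by simulation. The sketch below assumes the relaxed sample has the form $z = \mathrm{sigmoid}((F^{-1}(\theta) - s)/\tau)$ with $s \sim F$, which is the form consistent with the equivalence "$z \le t$ if and only if $s \ge M$" used above, and it takes $F$ to be the standard normal CDF. The function names, the sample size, and the chosen values of $\theta$, $\tau$, and $t$ are illustrative assumptions.

```python
import math
import random
from statistics import NormalDist

def sigmoid(a):
    # Numerically stable logistic function.
    if a >= 0:
        return 1.0 / (1.0 + math.exp(-a))
    e = math.exp(a)
    return e / (1.0 + e)

def icdf_samples(theta, tau, dist, rng, n):
    """Draw n relaxed samples z = sigmoid((F^{-1}(theta) - s) / tau), s ~ F (assumed form)."""
    q = dist.inv_cdf(theta)
    return [sigmoid((q - rng.gauss(dist.mean, dist.stdev)) / tau) for _ in range(n)]

def icdf_cdf(t, theta, tau, dist):
    """Middle segment of the CDF: Pr(z <= t) = 1 - F(F^{-1}(theta) + tau * log(1/t - 1))."""
    return 1.0 - dist.cdf(dist.inv_cdf(theta) + tau * math.log(1.0 / t - 1.0))

rng = random.Random(0)
dist = NormalDist(0.0, 1.0)   # F: standard normal (an assumption for this sketch)
theta, n = 0.3, 100_000

# Moderate temperature: the empirical CDF at t = 0.4 should match the closed form.
zs = icdf_samples(theta, 0.5, dist, rng, n)
empirical = sum(z <= 0.4 for z in zs) / n
closed_form = icdf_cdf(0.4, theta, 0.5, dist)

# Small temperature: z is nearly binary and is close to 1 with probability about theta.
zs_cold = icdf_samples(theta, 1e-3, dist, rng, n)
frac_one = sum(z > 0.5 for z in zs_cold) / n

print(empirical, closed_form, frac_one)
```

With a normal $F$ the support is unbounded, so only the middle segment of the CDF is relevant, and the fraction of near-one samples at small $\tau$ should be close to $\theta$, in line with the convergence to $\mathrm{Ber}(\theta)$.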

D TUNING GUIDANCE FOR TEMPERATURE τ

Our tuning guidance for the temperature $\tau$ is motivated by an asymptotic convergence comparison between the ICDF and Gumbel re-parameterizations, which is stated in the theorem below.

Theorem 3. When $\tau$ is small,

$$\mathrm{Bias}(y_1) = \tfrac{1}{6}\tau^2\pi^2\,\theta(1-\theta)(1-2\theta) + O(\tau^4), \tag{14}$$

$$\mathrm{Bias}(z) = \tfrac{1}{6}\tau^2\pi^2\,F''(F^{-1}(\theta)) + O(\tau^4). \tag{15}$$

Moreover, when $F$ is the CDF of a normal variable $\sim \mathcal{N}(0, \sigma^2)$,

$$\mathrm{Bias}(z) = -\frac{1}{6\sigma^2}\tau^2\pi^{3/2}\,\mathrm{erf}^{-1}(2\theta-1)\,e^{-(\mathrm{erf}^{-1}(2\theta-1))^2} + O(\tau^4). \tag{16}$$

The formal proof is detailed in Appendix E. Theorem 3 suggests that the ICDF method converges as fast as the Gumbel trick, both on the order of $O(\tau^2)$. On the other hand, the biases depend on $\theta$. Thus, one cannot set the temperatures $\tau$ independently of the desired probability $\theta$ to equate the two biases. In practice, $\tau$ is a tunable hyper-parameter, and guidance on the tuning range is therefore necessary.

To begin, we use a subscript to distinguish the two temperatures, $\tau_g$ for the Gumbel trick and $\tau_i$ for the ICDF method, and write, based on (14) and (16) and ignoring the high-order terms,

$$\frac{\mathrm{Bias}(y_1)}{\mathrm{Bias}(z)} \simeq \frac{\tau_g^2\,\sigma^2}{\tau_i^2}\, r(\theta), \quad \text{where} \quad r(\theta) = \frac{\sqrt{\pi}\,\theta(1-\theta)(2\theta-1)}{\mathrm{erf}^{-1}(2\theta-1)\,e^{-(\mathrm{erf}^{-1}(2\theta-1))^2}}.$$

Note that $r(\theta)$ is symmetric around $\theta = \tfrac{1}{2}$, is concave, attains its maximum $\tfrac{1}{2}$ at $\theta = \tfrac{1}{2}$, and attains its minimum $0$ at $\theta = 0, 1$. Hence, if $\tau_g = \tau_i$ and $\sigma = \sqrt{2}$, the ratio is at most one and the bias of the Gumbel trick is (approximately) smaller than that of the ICDF method. On the other hand, for any $\sigma > \sqrt{2}$, there exist $\theta_1 < \theta_2$ such that $\sigma^{-2} = r(\theta_1) = r(\theta_2)$ and $\mathrm{Bias}(y_1) \gtrapprox \mathrm{Bias}(z)$ whenever $\theta \in [\theta_1, \theta_2]$. For example, when $\sigma \approx 2.5$, on the interval $\theta \in [0.01, 0.99]$, the bias of the Gumbel trick is (approximately) greater than that of the ICDF method.

Based on the foregoing, a practical guide is to use the same tuning range of $\tau$ for the ICDF method as for the Gumbel trick. A small change of $\sigma$ (e.g., $\sqrt{2}$ versus $2.5$) entirely flips the landscape of the bias comparison between the two methods. Because the tuning range is much wider than the change of $\sigma$, for simplicity it suffices to fix $\sigma = 1$.
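To make the comparison concrete, the ratio function $r(\theta)$ can be evaluated numerically. The following is a small sketch, assuming $r(\theta)$ is the quotient of $\sqrt{\pi}\,\theta(1-\theta)(2\theta-1)$ and $\mathrm{erf}^{-1}(2\theta-1)\,e^{-(\mathrm{erf}^{-1}(2\theta-1))^2}$ as implied by the leading terms of the two biases. The Python standard library has no inverse error function, so the sketch obtains it from the standard normal quantile; the function names are illustrative.

```python
import math
from statistics import NormalDist

def erfinv(x):
    # erf^{-1}(x) from the standard normal quantile: Phi^{-1}(p) = sqrt(2) * erfinv(2p - 1).
    return NormalDist().inv_cdf((x + 1.0) / 2.0) / math.sqrt(2.0)

def r(theta):
    """Ratio function r(theta) for 0 < theta < 1; the value at 1/2 is its limit, 1/2."""
    if theta == 0.5:
        return 0.5
    e = erfinv(2.0 * theta - 1.0)
    numerator = math.sqrt(math.pi) * theta * (1.0 - theta) * (2.0 * theta - 1.0)
    return numerator / (e * math.exp(-e * e))

def bias_ratio(theta, sigma, tau_g=1.0, tau_i=1.0):
    """Leading-order Bias(y_1) / Bias(z) = (tau_g^2 * sigma^2 / tau_i^2) * r(theta)."""
    return (tau_g ** 2) * (sigma ** 2) / (tau_i ** 2) * r(theta)

for theta in (0.01, 0.1, 0.3, 0.5):
    print(theta,
          round(r(theta), 4),
          round(bias_ratio(theta, math.sqrt(2.0)), 4),   # equal temperatures, sigma = sqrt(2)
          round(bias_ratio(theta, 2.5), 4))              # equal temperatures, sigma = 2.5
```

With equal temperatures, the third column should stay at or below one and the fourth column should stay at or above roughly one for $\theta$ between 0.01 and 0.99, which is the flip described in the text.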

E PROOF OF THEOREM 3 AND ADDITIONAL RESULTS

By the definition of bias, we have $\mathrm{Bias}(x) = \mathrm{E}[x] - \theta$, where

$$\mathrm{E}[x] = \int_0^1 t \, d\Pr(x \le t) = 1 - \int_0^1 \Pr(x \le t)\, dt.$$

Therefore, for Gumbel softmax,

$$\mathrm{Bias}(y_1) = 1 - \theta - \int_0^1 \frac{t^\tau(1-\theta)}{t^\tau(1-\theta) + (1-t)^\tau\theta}\, dt,$$

and for ICDF with any $F$,

$$\mathrm{Bias}(z) = \int_0^1 F\big(F^{-1}(\theta) + \tau \log(t^{-1} - 1)\big)\, dt - \theta.$$

We now prove Theorem 3 in a few parts.

Proof of (15). Let $s = F^{-1}(\theta)$ and perform a change of variable $m = \log(t^{-1} - 1)$. Then,

$$\mathrm{Bias}(z) = \int_0^1 [F(s + \tau m) - F(s)]\, dt = \int_{-\infty}^{\infty} [F(s + \tau m) - F(s)]\, \frac{e^m}{(1+e^m)^2}\, dm.$$

We perform a Taylor expansion of $F$ around $s$ and obtain

$$F(s + \tau m) - F(s) = \sum_{n=1}^{\infty} \frac{F^{(n)}(s)}{n!}\, \tau^n m^n.$$

Therefore,

$$\mathrm{Bias}(z) = \sum_{n=1}^{\infty} \frac{F^{(n)}(s)}{n!}\, \tau^n \int_{-\infty}^{\infty} m^n\, \frac{e^m}{(1+e^m)^2}\, dm.$$

Each integral term is finite, and the odd terms vanish because the integrands are odd functions. Thus, for small $\tau$, we are left with

$$\mathrm{Bias}(z) = \frac{F''(s)}{2}\, \tau^2 \int_{-\infty}^{\infty} m^2\, \frac{e^m}{(1+e^m)^2}\, dm + O(\tau^4).$$

The definite integral evaluates to $\pi^2/3$; we therefore conclude the proof.

Proof of (16). Equation (16) follows by substituting

$$F''(s) = -\frac{s}{\sigma^3\sqrt{2\pi}}\, e^{-\frac{s^2}{2\sigma^2}} = -\frac{\mathrm{erf}^{-1}(2\theta-1)}{\sigma^2\sqrt{\pi}}\, e^{-(\mathrm{erf}^{-1}(2\theta-1))^2}$$

into (15).

Proof of (14). To simplify notation, let $\beta = \theta/(1-\theta)$ and perform a change of variable $m = \log(t^{-1} - 1)$. Then,

$$\int_0^1 \frac{t^\tau(1-\theta)}{t^\tau(1-\theta) + (1-t)^\tau\theta}\, dt = \int_0^1 \frac{dt}{1 + \beta e^{m\tau}} = \int_{-\infty}^{\infty} \frac{1}{1 + \beta e^{m\tau}}\, \frac{e^m}{(1+e^m)^2}\, dm.$$

Denote $h(\tau, m) = [1 + \beta e^{m\tau}]^{-1}$. Treating $h$ as a function of $\tau$ and performing a Taylor expansion around zero, we obtain

$$h(\tau, m) = \sum_{n=0}^{\infty} \frac{h^{(n)}(0, m)}{n!}\, \tau^n.$$

Therefore,

$$\int_0^1 \frac{t^\tau(1-\theta)}{t^\tau(1-\theta) + (1-t)^\tau\theta}\, dt = \sum_{n=0}^{\infty} \frac{\tau^n}{n!} \int_{-\infty}^{\infty} h^{(n)}(0, m)\, \frac{e^m}{(1+e^m)^2}\, dm.$$

In a moment, we will show that for all $n$,

$$h^{(n)}(0, m) = C_n m^n, \quad \text{where } C_n \text{ is independent of } m. \tag{17}$$

Suppose that (17) holds. Then, each integral term is finite, and the odd terms vanish because the integrands are odd functions. Therefore, for small $\tau$, we are left with

$$\int_0^1 \frac{t^\tau(1-\theta)}{t^\tau(1-\theta) + (1-t)^\tau\theta}\, dt = C_0 \int_{-\infty}^{\infty} \frac{e^m}{(1+e^m)^2}\, dm + \frac{C_2 \tau^2}{2} \int_{-\infty}^{\infty} m^2\, \frac{e^m}{(1+e^m)^2}\, dm + O(\tau^4).$$

By calculating

$$C_0 = h(0, m) = [1 + \beta]^{-1} = 1 - \theta, \qquad C_2 = \frac{h''(0, m)}{m^2} = -\theta(1-\theta)(1-2\theta),$$

$$\int_{-\infty}^{\infty} \frac{e^m}{(1+e^m)^2}\, dm = 1, \qquad \int_{-\infty}^{\infty} m^2\, \frac{e^m}{(1+e^m)^2}\, dm = \frac{\pi^2}{3},$$

we conclude that

$$\mathrm{Bias}(y_1) = \frac{\tau^2\pi^2\,\theta(1-\theta)(1-2\theta)}{6} + O(\tau^4).$$

It remains to prove (17). We suppress the argument $m$ and write $g(\tau) = 1 + \beta e^{m\tau}$ and $h(\tau) = g(\tau)^{-1}$. By Faà di Bruno's formula,

$$h^{(n)}(0) = \left. \left( \frac{1}{g(\tau)} \right)^{(n)} \right|_{\tau=0} = \sum_{k=1}^{n} \frac{(-1)^k k!}{g(0)^{k+1}} \cdot B_{n,k}\big(g'(0), g''(0), \ldots, g^{(n-k+1)}(0)\big),$$

where $B_{n,k}$ is the Bell polynomial. Clearly, $g(0) = 1 + \beta$ and $g^{(r)}(0) = \beta m^r$ for all $r > 0$. Hence, $B_{n,k}$ is a multiple of $m^n$. Therefore, $h^{(n)}(0)$ is a multiple of $m^n$.

E.1 ADDITIONAL RESULT REGARDING THE BIAS

Theorem 3 states results for a small temperature $\tau$. The purpose is to understand the limiting behavior of the bias. Here, we give an additional result for any $\tau > 0$. It states that the biases of the two sampling approaches have the same sign. This result is a nontrivial extension of Theorem 3 and requires a different proof technique.

Theorem 4. For any $\tau > 0$,

$$\mathrm{Bias}(y_1) > 0 \text{ when } \theta < \tfrac{1}{2}, \qquad \mathrm{Bias}(y_1) = 0 \text{ when } \theta = \tfrac{1}{2}, \qquad \mathrm{Bias}(y_1) < 0 \text{ when } \theta > \tfrac{1}{2}. \tag{18}$$

Moreover, if $F'(x)$ (that is, the pdf) is even and is increasing when $x < 0$, then

$$\mathrm{Bias}(z) > 0 \text{ when } \theta < \tfrac{1}{2}, \qquad \mathrm{Bias}(z) = 0 \text{ when } \theta = \tfrac{1}{2}, \qquad \mathrm{Bias}(z) < 0 \text{ when } \theta > \tfrac{1}{2}. \tag{19}$$

We prove Theorem 4 in two parts.

Proof of (18). Consider

$$\mathrm{Bias}(y_1) = \int_0^1 g(t, \theta)\, dt, \quad \text{where} \quad g(t, \theta) = 1 - \theta - \frac{t^\tau(1-\theta)}{t^\tau(1-\theta) + (1-t)^\tau\theta}.$$

With a brute-force calculation, we have

$$g(t, \theta) + g(1-t, \theta) = \frac{[(1-t)^\tau - t^\tau]^2\, \theta(1-\theta)(1-2\theta)}{[t^\tau(1-\theta) + (1-t)^\tau\theta]\,[(1-t)^\tau(1-\theta) + t^\tau\theta]}.$$

For $t \ne \tfrac{1}{2}$, all factors on the right-hand side are positive, except $1 - 2\theta$. Therefore, when $\theta < \tfrac{1}{2}$, $g(t, \theta) + g(1-t, \theta) > 0$ for $t \ne \tfrac{1}{2}$ and hence

$$\mathrm{Bias}(y_1) = \int_0^1 \frac{g(t, \theta) + g(1-t, \theta)}{2}\, dt > 0.$$

The other cases ($\theta > \tfrac{1}{2}$ and $\theta = \tfrac{1}{2}$) are similarly proved.

Proof of (19). Consider

$$\mathrm{Bias}(z) = \int_0^1 h(t, \theta)\, dt - \theta, \quad \text{where} \quad h(t, \theta) = F\big(F^{-1}(\theta) + \tau \log(t^{-1} - 1)\big).$$

We have $h(1-t, \theta) = F\big(F^{-1}(\theta) - \tau \log(t^{-1} - 1)\big)$. To simplify notation, let $F^{-1}(\theta) = s$ and $\tau \log(t^{-1} - 1) = a$. Then, $h(t, \theta) = F(s + a)$ and $h(1-t, \theta) = F(s - a)$. Let us first consider the case $s < 0$ and $a > 0$. We see that

$$F(s + a) - F(s) = \int_s^{s+a} F'(m)\, dm \quad \text{and} \quad F(s) - F(s - a) = \int_{s-a}^{s} F'(m)\, dm.$$

For any $b > 0$, if $s + b < 0$, then by monotonicity, $F'(s + b) > F'(s - b)$. On the other hand, if $s + b \ge 0$, then $F'(s + b) = F'(-s - b) > F'(s - b)$. In both cases, the right integral is always smaller than the left integral. In other words,

$$F(s + a) + F(s - a) > 2F(s).$$

In fact, the above inequality also holds when $s < 0$ and $a < 0$. Therefore, whenever $s < 0$,

$$\int_0^1 h(t, \theta)\, dt = \int_0^1 \frac{h(t, \theta) + h(1-t, \theta)}{2}\, dt > \int_0^1 F(F^{-1}(\theta))\, dt = \theta.$$

That is, $\mathrm{Bias}(z) > 0$. The other cases ($s = F^{-1}(\theta) > 0$ and $s = F^{-1}(\theta) = 0$) are similarly proved.
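Both theorems can be checked by numerical integration. The sketch below evaluates the two bias integrals in the variable $m = \log(t^{-1} - 1)$, where the weight $e^m/(1+e^m)^2$ decays exponentially, and compares them with the leading terms of Theorem 3. It assumes a normal $F$ with standard deviation `sigma`; the function names, the truncation of the integration range to $[-40, 40]$, and the trapezoidal rule are choices made for this illustration.

```python
import math
from statistics import NormalDist

def logistic_weight(m):
    # e^m / (1 + e^m)^2, written in a symmetric, overflow-free form.
    return 0.25 / math.cosh(0.5 * m) ** 2

def integrate(f, lo=-40.0, hi=40.0, steps=8000):
    # Composite trapezoidal rule; the integrands decay like exp(-|m|).
    h = (hi - lo) / steps
    total = 0.5 * (f(lo) + f(hi)) + sum(f(lo + k * h) for k in range(1, steps))
    return total * h

def bias_gumbel(theta, tau):
    """Bias(y_1) = 1 - theta - integral of w(m) / (1 + beta * exp(m * tau)) dm."""
    beta = theta / (1.0 - theta)
    def f(m):
        # Cap the exponent to avoid overflow; the capped term is numerically zero anyway.
        return logistic_weight(m) / (1.0 + beta * math.exp(min(m * tau, 700.0)))
    return 1.0 - theta - integrate(f)

def bias_icdf(theta, tau, sigma=1.0):
    """Bias(z) = integral of w(m) * F(F^{-1}(theta) + tau * m) dm - theta, F = N(0, sigma^2)."""
    dist = NormalDist(0.0, sigma)
    s = dist.inv_cdf(theta)
    return integrate(lambda m: logistic_weight(m) * dist.cdf(s + tau * m)) - theta

def bias_gumbel_leading(theta, tau):
    # Leading term of (14).
    return tau ** 2 * math.pi ** 2 * theta * (1 - theta) * (1 - 2 * theta) / 6.0

def bias_icdf_leading(theta, tau, sigma=1.0):
    # Leading term of (16); e equals erf^{-1}(2 * theta - 1).
    e = NormalDist().inv_cdf(theta) / math.sqrt(2.0)
    return -(tau ** 2) * math.pi ** 1.5 * e * math.exp(-e * e) / (6.0 * sigma ** 2)

for theta in (0.3, 0.5, 0.7):
    print(theta,
          bias_gumbel(theta, 0.05), bias_gumbel_leading(theta, 0.05),
          bias_icdf(theta, 0.05), bias_icdf_leading(theta, 0.05),
          bias_gumbel(theta, 1.0), bias_icdf(theta, 1.0))
```

For the small temperature, the integrated biases should agree with the leading terms up to the $O(\tau^4)$ remainder; for the larger temperature, only the sign pattern of Theorem 4 is expected to hold.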

E.2 EMPIRICAL COMPARISON BETWEEN GUMBEL AND ICDF RE-PARAMETERIZATION

Extending the last experiment in Section 6, Table 4 summarizes the time and memory consumption during the training of global models on the four data sets. The results indicate that our ICDF re-parameterization is more economical than the Gumbel-Softmax approach.

The traffic data sets (METR-LA and PEMS-BAY) were originally prepared for forecasting tasks and hence carry no labels. We adapt the data for classification. Specifically, we split the time series on the hour, forming hourly windows, and label each window according to whether or not it corresponds to rush hour. For proof of concept, we specify 07:00-10:00 and 16:00-19:00 as rush hour and all other times as non-rush hour. In the original data sets, one of the attributes is time. We remove this attribute to avoid triviality and retain only the speed attribute. The specification of rush hours may not be highly accurate, but it is a sensible way to cope with the absence of labels. Intuitively, the signal of rush hour comes from reduced traffic speed, but not every location in the network experiences a traffic jam. Hence, the diverse traffic patterns inside the same time window under a single label pose a nontrivial challenge for local models to discern. Therefore, the need for a global consensus model is justified, and the task fits the federated inference scenario well.

PMU-B and PMU-C. These are proprietary data sets provided in coordination by multiple data owners of the U.S. power grid. No personally identifiable information is present. The suffixes B and C indicate the interconnects of the grid. The data sets come with thousands of annotated grid events spanning a period of two years; these form the classification labels. Many variables (attributes) of the grid condition are recorded; we select only the voltage magnitude and the current magnitude, because domain knowledge suggests that they are the strongest signals for event detection, and also because more data are available for these two variables. The grid topology is not available. For each event, we select a one-second window from the three-minute window that covers the approximate annotated event time, based on the largest z-score. We retain a sampling frequency of 30 Hz, even though some data are sampled at 60 Hz. Furthermore, a large amount of the raw data is missing. We impute the series by using pandas.DataFrame.interpolate(method='linear', limit_direction='both') from the Python pandas package. This way, a windowed series is complete if it contains any raw data. Even so, many series are entirely empty, which corresponds to the scenario illustrated by Figure 1. Classes in these two data sets are rather skewed. For PMU-B, we remove a class that consists of only one data point, and for PMU-C, we combine classes that contain fewer than 24 data points into a single class.
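The rush-hour labeling described above for the traffic data can be sketched in a few lines. This is an illustrative sketch only: it assumes five-minute readings with time stamps, treats the rush-hour ranges as half-open intervals (windows starting at 07:00, 08:00, 09:00, 16:00, 17:00, and 18:00), and uses a constant dummy speed and an arbitrary date; the function names are hypothetical.

```python
from datetime import datetime, timedelta

# Start hours of rush-hour windows, assuming half-open ranges [07:00, 10:00) and [16:00, 19:00).
RUSH_HOURS = set(range(7, 10)) | set(range(16, 19))

def hourly_windows(timestamps, speeds):
    """Group (timestamp, speed) readings into windows keyed by the start of the hour."""
    windows = {}
    for ts, speed in zip(timestamps, speeds):
        start = ts.replace(minute=0, second=0, microsecond=0)
        windows.setdefault(start, []).append(speed)
    return windows

def label_windows(windows):
    """Return (speed series, label) pairs; the time stamp is used for labeling only."""
    return [(series, int(start.hour in RUSH_HOURS))
            for start, series in sorted(windows.items())]

# Toy data: one day of five-minute readings with a constant dummy speed.
t0 = datetime(2012, 3, 1)  # arbitrary date, for illustration only
timestamps = [t0 + timedelta(minutes=5 * k) for k in range(288)]
speeds = [60.0] * len(timestamps)

labeled = label_windows(hourly_windows(timestamps, speeds))
print(len(labeled), sum(label for _, label in labeled))
```

Each labeled example keeps only the speed series, mirroring the removal of the time attribute so that the label cannot be read off the input directly.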

G EXPERIMENT DETAILS

The experiments are conducted on one x86 node of a computing cluster with one NVIDIA A100 GPU. The compute node has eight Intel cores and 128 GB of memory. For each data set, we perform a 70/10/20 random split for training, validation, and testing, respectively. For local models, we use LSTMs with the same hyperparameters: one hidden layer with hidden dimension 16 and a maximum of 200 epochs. We pre-train the local models and freeze their parameters afterward. We train each global model for a maximum of 500 epochs and use early stopping according to the validation loss, with a patience of 50 epochs. For the GNN global model, we use a 2-layer GCN with skip connections. The hidden dimension is set to 8, and we select the learning rate from {0.01, 0.001}. For missing data, we impute the node features with zeros.
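The early-stopping rule for the global model can be written independently of any training framework. The sketch below is an assumption about how such a loop is commonly organized, not the authors' code: `train_epoch` and `validate` are hypothetical callbacks, and the toy validation curve only serves to exercise the stopping logic with the stated limits (at most 500 epochs, patience of 50).

```python
def train_with_early_stopping(train_epoch, validate, max_epochs=500, patience=50):
    """Train for up to max_epochs; stop after `patience` epochs without validation improvement."""
    best_loss, best_epoch, epochs_run = float("inf"), -1, 0
    for epoch in range(max_epochs):
        train_epoch(epoch)
        loss = validate(epoch)
        epochs_run = epoch + 1
        if loss < best_loss:
            # A real run would also checkpoint the model parameters here.
            best_loss, best_epoch = loss, epoch
        elif epoch - best_epoch >= patience:
            break
    return best_epoch, best_loss, epochs_run

# Toy validation curve: improves until epoch 100, then degrades.
def toy_validation_loss(epoch):
    return abs(epoch - 100) / 100.0

best_epoch, best_loss, epochs_run = train_with_early_stopping(
    lambda epoch: None, toy_validation_loss)
print(best_epoch, best_loss, epochs_run)
```

In this toy run the best epoch is the one with the lowest validation loss, and training continues for exactly `patience` further epochs before stopping.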

H SOFT AND HARD FEATURE ALIGNMENT

Feature alignment can be achieved in two manners. The first approach is a soft alignment, which treats each $P_i$ as a free parameter matrix to optimize. Such an alignment softens the one-to-one correspondence in the permutation constraint; i.e., each feature in the source can have a weighted correspondence to each of the features in the target. This is the approach used in the main paper.

An alternative approach is a hard alignment, which treats each $P_i$ as a permutation matrix. Learning permutation matrices is challenging, however, because they correspond to combinatorial structures and are unsuitable for gradient-based training. We follow Mena et al. (2018); Emami & Ranka (2018) and relax $P_i$ to a doubly stochastic matrix, which can be differentiably parameterized by the Sinkhorn-Knopp algorithm (Sinkhorn & Knopp, 1967). Specifically, starting from a nonnegative square matrix $K_0$ and column vectors $r_0 = c_0 = \mathbf{1}$ of matching lengths, define the sequence

$$c_{j+1} = \mathbf{1} \oslash (K_0^{\top} r_j) \quad \text{and} \quad r_{j+1} = \mathbf{1} \oslash (K_0 c_{j+1}), \quad \text{for } j = 0, 1, \ldots \tag{20}$$

Then, under a mild condition, $K_j := \mathrm{diag}(r_j) K_0\, \mathrm{diag}(c_j)$ converges to a doubly stochastic matrix. We truncate the sequence at the $T$th step and treat $K_T$ as an approximation of $P_i$.

Despite the advocacy by Mena et al. (2018); Emami & Ranka (2018), we obtain the following convergence result for Sinkhorn-Knopp, which reveals that there is no free lunch.

Theorem 5 (informal). Under a condition on $K_0$, there exist a positive integer $J$ and a constant $C_J$ such that for all $j \ge J$,

$$\left\| \begin{bmatrix} K_j^{\top} \mathbf{1} \\ K_j \mathbf{1} \end{bmatrix} - \begin{bmatrix} \mathbf{1} \\ \mathbf{1} \end{bmatrix} \right\| \le C_J\, (1 + \sigma_2^2)\, \sigma_2^{2(j-J)},$$

where $\sigma_2 \le 1$ is the second largest singular value of the limit of $K_j$.

Since this is not the focus of this paper, we omit the rigorous analysis of this theorem. The result suggests that when the desired limit is a permutation matrix, for which $\sigma_2 = 1$, the error $O(\sigma_2^{2j})$ does not drop. In practice, if one expects an approximate permutation matrix, then $\sigma_2 \approx 1$ and the convergence is exceedingly slow. The practical usefulness of (20) therefore depends on the quality of the learned $K_0$.

The soft and hard alignment approaches each have pros and cons. The hard approach maintains the correspondence of each feature dimension of the latent vectors, while the soft approach does not. Maintaining the dimension correspondence is an advantage, especially for local models that produce disentangled latent representations (Higgins et al., 2018), because each feature dimension is equipped with a semantic meaning that controls a certain aspect of the data. On the other hand, the soft approach is more straightforward, and the hard approach is based on an algorithm that barely converges. In practice, we observe that the two approaches yield similar performance. Due to space limitations, we took the simpler approach and presented only the soft version in the main paper, but here we list the results for both approaches in Table 5, and we also visualize the hard alignment matrix learned on each data set to help readers understand the feature alignment (Figure 4).
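For readers unfamiliar with Sinkhorn-Knopp scaling, the following is a minimal plain-Python sketch of the truncated iteration. It assumes the standard alternating form, in which the row scaling is computed from the freshly updated column scaling, and it uses a small hand-picked positive matrix `K0`; in the framework itself $K_0$ would be learned, and the function names here are illustrative.

```python
def sinkhorn_knopp(K0, steps):
    """Return K_T = diag(r_T) K0 diag(c_T) after `steps` alternating scaling updates."""
    n = len(K0)
    r = [1.0] * n
    c = [1.0] * n
    for _ in range(steps):
        # c <- 1 ./ (K0^T r): rescale so that every column sums to one.
        c = [1.0 / sum(K0[i][j] * r[i] for i in range(n)) for j in range(n)]
        # r <- 1 ./ (K0 c): rescale so that every row sums to one.
        r = [1.0 / sum(K0[i][j] * c[j] for j in range(n)) for i in range(n)]
    return [[r[i] * K0[i][j] * c[j] for j in range(n)] for i in range(n)]

def marginal_error(K):
    """Largest deviation of any row sum or column sum from one."""
    n = len(K)
    rows = [abs(sum(K[i]) - 1.0) for i in range(n)]
    cols = [abs(sum(K[i][j] for i in range(n)) - 1.0) for j in range(n)]
    return max(rows + cols)

# Hand-picked positive matrix (an assumption for illustration).
K0 = [[4.0, 1.0, 0.5],
      [0.5, 3.0, 1.0],
      [1.0, 0.5, 5.0]]

errors = [marginal_error(sinkhorn_knopp(K0, T)) for T in (1, 5, 20)]
print(errors)
```

For a strictly positive, well-mixed matrix such as this one the marginal error shrinks quickly; the slow regime discussed in the text arises when the limit is close to a permutation matrix.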



Note that in our case, synchronicity requires co-training among clients, which is a weaker constraint than its usual meaning of further requiring clients to synchronize their updates per iteration.

Note that the approach proposed by Hu et al. (2019) assumes no parameters for the global model. Were global parameters present, gradient communication would be inevitable.



Figure 1: Federated Feature Fusion: A global prediction is produced collectively based on a set of global features which are the result of fusing local feature representations supplied by the data owners. These feature representations are induced from locally trained models on raw local data which might be heterogeneous.

Figure 2: Federated Feature Fusion Framework. Local models are trained independently and separately from the global model. The algorithm is summarized in Algorithm 1.

Figure 3: Distributions of prediction entropy across local models.

Our framework also advocates learning a graph of data owners in the global model. Graph structure learning appears in various contexts. One field of study is probabilistic graphical models and causal inference, whereby a directed acyclic structure is learned. Gradient-based approaches in this context include Zheng et al. (2018); Yu et al. (2019); Lachapelle et al. (2020). On the other hand, a general graph may still be useful without resorting to causality. Recent approaches supporting GNN-based modeling include Kipf et al. (2018); Franceschi et al. (2019); Wu et al. (2020); Shang et al.

$\frac{e^m}{(1+e^m)^2}\, dm$. Denote $h(\tau, m) = [1 + \beta e^{m\tau}]^{-1}$. Treating $h$ as a function of $\tau$ and performing a Taylor expansion around zero, we obtain

Figure 4: Examples of learned permutation matrices ($K_T$). The plots clearly show patterns of a permutation matrix: there is one and only one significant value per row and per column. Because of the slow convergence, we attribute the desirable results of $K_T$ (at a small $T$) to the success of the learning of $K_0$. Note also, interestingly, that a learned permutation may be the identity mapping.

Datasets.

Table 2: Effectiveness of latent alignment in a graph-based global model. Each entry is a mean followed by its standard deviation (mean±std). Note that baselines (A) and (H) are not applicable to federated feature fusion, where (local) model homogenization and training synchronicity are not allowed.

Best Model Selection            .53±.000  .70±.000  .55±.000  .79±.000  .37±.000  .69±.000  .32±.000  .62±.000
E: Mean Pooling - Eq. (5)       .77±.009  .96±.004  .74±.012  .93±.001  .38±.008  .71±.006  .34±.008  .64±.010
F: Transformer                  .78±.023  .94±.018  .72±.045  .93±.027  .39±.003  .70±.009  .40±.053  .67±.058
G: Concatenation                .83±.008  .97±.002  .80±.066  .96±.028  .39±.006  .71±.036  .40±.025  .68±.040
H: $F^3$ with no alignment      .80±.009  .96±.004  .75±.009  .94±.001  .39±.003  .73±.015  .40±.020  .66±.018
J: $F^3$ with parameter tying   .82±.009  .97±.001  .75±.009  .94±.004  .39±.006  .72±.010  .37±.012  .66±.008
K: $F^3$ with alignment         .83±.010  .97±.001  .86±.005  .98±.002  .39±.008  .73±.008  .45±.015  .72±.003
L: VFL w. graph/alignment       .83±.012  .97±.001  .86±.014  .98±.002  .39±.006  .74±.009  .45±.015  .73±.003
M: VFL w.o. pretrained local    .77±.02   .94±.021  .77±.014  .95±.006  .34±.014  .69±.012  .35±.008  .65±.014

Table 3: Impact of learning graph across different alignment settings. Each entry is a mean followed by its standard deviation (mean±std). ⋆ Some references of rows are with respect to Table 2.

±.009  .957±.004  .738±.012  .935±.001  .381±.008  .711±.006  .342±.008  .636±.010
Given Graph  .763±.020  .957±.007  .742±.024  .942±.005
Graph        .715±.015  .952±.004  .695±.013  .934±.004  .372±.001  .711±.013  .404±.016  .68±.014
ICDF         .798±.009  .963±.004  .755±.009  .943±.001  .387±.003  .734±.015  .403±.020  .663±.018

±.009  .970±.002  .846±.008  .977±.001  .386±.009  .725±.012  .386±.008  .694±.005
Given Graph  .828±.007  .974±.001  .854±.003  .977±.001
Graph        .803±.020  .968±.002  .855±.003  .973±.002  .378±.002  .718±.015  .418±.007  .702±.009
ICDF         .835±.010  .975±.001  .860±.005  .980±.002  .390±.008  .734±.008  .451±.015  .725±.003

Table 4: Time and memory consumption of $F^3$ (five epochs) with respect to ICDF and Gumbel-Softmax re-parameterization. Time is in seconds and memory is in MB.

METR-LA and PEMS-BAY. These are traffic data sets (MIT licensed) used by Li et al. (2018). The former was collected from loop detectors on the highways of Los Angeles, CA (Jagadish et al., 2014), and the latter was collected by the California Transportation Agencies Performance Measurement System. Both data sets record several months of data at a resolution of five minutes. The network graphs are available; they were constructed by applying a radial basis function to the pairwise distances between sensors, with a certain cutoff.

Table 5: Different approaches to feature alignment.

soft alignment  .835  .975  .860  .980  .390  .734  .451  .725
hard alignment  .839  .973  .855  .976  .390  .737  .429  .721


One can arbitrarily permute the elements of $h$ by manipulating the GRU parameters properly. To achieve $h[p] = \mathrm{embedding}(x; p)$,

• the gate outputs and bias vectors ($z_t$, $r_t$, $n_t$, $h_t$, $b_z$, $b_r$, $b_n$) will need to be permuted accordingly;
• the weight matrices attached to the input ($W_z$, $W_r$, $W_n$) will need to have their rows (i.e., output neurons) permuted ($W_z[p, :]$, $W_r[p, :]$, $W_n[p, :]$); and
• the weight matrices attached to the hidden states ($U_z$, $U_r$, $U_n$) will need to have both their rows and columns permuted.

C PROOFS AND ADDITIONAL RESULTS OF THEOREM 1, SECTION 5

C.1 DISTRIBUTION OF GUMBEL SOFTMAX

The Gumbel softmax reparameterization trick (Jang et al., 2017; Maddison et al., 2017) works in the following manner. Let $\mathrm{Cat}(\pi)$ be the categorical distribution with probability vector $\pi$ and let $g$, of the same shape as $\pi$, be a vector variable whose elements are i.i.d. $\sim \mathrm{Gumbel}(0, 1)$. Then,

$$y = \mathrm{softmax}\big((\log \pi + g)/\tau\big) \tag{11}$$

admits a distribution converging to $\mathrm{Cat}(\pi)$ when $\tau \to 0$. Hence, to sample $\mathrm{Ber}(\theta)$ approximately but differentiably, it suffices to let $\pi = [\theta, 1 - \theta]^{\top}$ and use $y_1$ as the sample.

As a preliminary, we consider the first entry $y_1$ of the random variable $y$ defined in (11) for the Gumbel softmax parameterization. Note that for any $\tau \ne 0$, $y_1$ is only approximately binary; the possible values of $y_1$ in fact span the entire interval $[0, 1]$. We derive the following CDF for $y_1$. Recall that for notational simplicity, $\theta$ denotes a scalar rather than a matrix.

Theorem 2. For all $\tau > 0$, $\theta \in (0, 1)$, and $t \in [0, 1]$, we have

$$\Pr(y_1 \le t) = \frac{t^\tau(1-\theta)}{t^\tau(1-\theta) + (1-t)^\tau\theta}. \tag{12}$$

Proof. We first consider the case $0 < t < 1$. Through simple algebraic manipulation, we obtain that $y_1 \le t$ is equivalent to

Let $g_1 = -\log(-\log u)$ and $g_2 = -\log(-\log v)$, where $u$ and $v$ are independent and $\sim U(0, 1)$. Then, (13) is equivalent to

Therefore, by recalling that $(u, v)$ is uniform in $[0, 1]^2$, we note that the probability that $v \le u^M$ happens is the double integral

This integral is nothing but

which completes the proof of (12). The cases of $t = 0$ or $1$ obviously hold by continuity.
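The CDF of $y_1$ can be checked by simulation. The sketch below assumes the usual Gumbel softmax form $y = \mathrm{softmax}((\log \pi + g)/\tau)$ with $\pi = [\theta, 1-\theta]^{\top}$ and compares the empirical distribution of $y_1$ with the closed form $t^\tau(1-\theta) / (t^\tau(1-\theta) + (1-t)^\tau\theta)$, which is the expression that appears in the bias integral of Appendix E. The function names, sample size, and parameter values are illustrative.

```python
import math
import random

def sample_gumbel(rng):
    # Gumbel(0, 1) by inverse transform; redraw in the rare event that u == 0.
    u = rng.random()
    while u == 0.0:
        u = rng.random()
    return -math.log(-math.log(u))

def gumbel_softmax_y1(theta, tau, rng):
    """First entry of softmax((log pi + g) / tau) with pi = [theta, 1 - theta] (assumed form)."""
    a1 = (math.log(theta) + sample_gumbel(rng)) / tau
    a2 = (math.log(1.0 - theta) + sample_gumbel(rng)) / tau
    top = max(a1, a2)  # subtract the maximum for numerical stability
    e1, e2 = math.exp(a1 - top), math.exp(a2 - top)
    return e1 / (e1 + e2)

def y1_cdf(t, theta, tau):
    """Closed form of Pr(y_1 <= t) for 0 < t < 1."""
    num = t ** tau * (1.0 - theta)
    return num / (num + (1.0 - t) ** tau * theta)

rng = random.Random(0)
theta, n = 0.3, 100_000

# Moderate temperature: compare the empirical CDF at t = 0.4 with the closed form.
ys = [gumbel_softmax_y1(theta, 0.5, rng) for _ in range(n)]
empirical = sum(y <= 0.4 for y in ys) / n
closed_form = y1_cdf(0.4, theta, 0.5)

# Small temperature: y_1 is nearly binary and exceeds 1/2 with probability about theta.
ys_cold = [gumbel_softmax_y1(theta, 1e-2, rng) for _ in range(n)]
frac_one = sum(y > 0.5 for y in ys_cold) / n

print(empirical, closed_form, frac_one)
```

The event $y_1 > 1/2$ does not depend on $\tau$, so the last quantity should be close to $\theta$ at any temperature; the small temperature merely makes the samples visibly concentrate near 0 and 1.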

