LEARNING GROUP IMPORTANCE USING THE DIFFERENTIABLE HYPERGEOMETRIC DISTRIBUTION

Abstract

Partitioning a set of elements into subsets of a priori unknown sizes is essential in many applications. These subset sizes are rarely explicitly learned, be it the cluster sizes in clustering applications or the number of shared versus independent generative latent factors in weakly-supervised learning. Probability distributions over valid combinations of subset sizes are non-differentiable due to hard constraints, which prohibits gradient-based optimization. In this work, we propose the differentiable hypergeometric distribution. The hypergeometric distribution models the probability of different group sizes based on their relative importance. We introduce reparameterizable gradients to learn the importance between groups and highlight the advantage of explicitly learning subset sizes in two typical applications: weakly-supervised learning and clustering. In both applications, we outperform previous approaches, which rely on suboptimal heuristics to model the unknown size of groups.

1. INTRODUCTION

Many machine learning approaches rely on differentiable sampling procedures, of which the reparameterization trick for Gaussian distributions is best known (Kingma & Welling, 2014; Rezende et al., 2014). The non-differentiable nature of discrete distributions has long hindered their use in machine learning pipelines with end-to-end gradient-based optimization. Only the concrete distribution (Maddison et al., 2017), or Gumbel-Softmax trick (Jang et al., 2016), boosted the use of categorical distributions in stochastic networks. Unlike the high-variance gradients of score-based methods such as REINFORCE (Williams, 1992), these works enable reparameterized and low-variance gradients with respect to the categorical weights. Despite enormous progress in recent years, the extension to more complex probability distributions is still missing or comes with a trade-off regarding differentiability or computational speed (Huijben et al., 2021).

The hypergeometric distribution plays a vital role in various areas of science, such as social and computer science and biology. Its applications range from modeling gene mutations and recommender systems to analyzing social networks (Becchetti et al., 2011; Casiraghi et al., 2016; Lodato et al., 2015). The hypergeometric distribution describes sampling without replacement and, therefore, models the number of samples per group given a limited number of total samples. Hence, it is essential wherever the choice of a single group element influences the probability of the remaining elements being drawn. Previous work mainly uses the hypergeometric distribution implicitly to model assumptions or as a tool to prove theorems. However, its hard constraints have prohibited integrating the hypergeometric distribution into gradient-based optimization processes. In this work, we propose the differentiable hypergeometric distribution.
It enables the reparameterization trick for the hypergeometric distribution and allows its integration into stochastic networks of modern, gradient-based learning frameworks. In turn, we learn the size of groups by modeling their relative importance in an end-to-end fashion. First, we evaluate our approach by performing a Kolmogorov-Smirnov test, in which we compare the proposed method to a non-differentiable reference implementation. Further, we highlight the advantages of our new formulation in two different applications in which previous work failed to explicitly learn the size of subgroups of samples. Our first application is a weakly-supervised learning task where two images share an unknown number of generative factors. The differentiable hypergeometric distribution learns the number of shared and independent generative factors between paired views through gradient-based optimization. In contrast, previous work has to infer these numbers from heuristics or rely on prior knowledge about the connection between images. Our second application integrates the hypergeometric distribution into a variational clustering algorithm. We model the number of samples per cluster using an adaptive hypergeometric distribution prior. By doing so, we overcome the simplified i.i.d. assumption and establish a dependency structure between dataset samples.

The contributions of our work are the following: i) we introduce the differentiable hypergeometric distribution, which enables its use in gradient-based optimization, ii) we demonstrate the accuracy of our approach by evaluating it against a reference implementation, and iii) we show the advantages of explicitly learning the size of groups in two different applications, namely weakly-supervised learning and clustering.

2. RELATED WORK

In recent years, finding continuous relaxations of discrete distributions and non-differentiable algorithms in order to integrate them into differentiable pipelines has gained popularity. Jang et al. (2016) and Maddison et al. (2017) concurrently propose the Gumbel-Softmax gradient estimator. It enables reparameterized gradients with respect to the parameters of the categorical distribution and their use in differentiable models. Methods to select k elements, instead of only one, were subsequently introduced. Kool et al. (2019; 2020a) implement sequential sampling without replacement using a stochastic beam search. Kool et al. (2020b) extend the sequential sampling procedure to a reparameterizable estimator using REINFORCE. Grover et al. (2019) propose a relaxed sorting procedure, which simultaneously serves as a differentiable and reparameterizable top-k element selection procedure. Xie & Ermon (2019) propose a relaxed subset selection algorithm to select a given number k out of n elements. Paulus et al. (2020) generalize stochastic softmax tricks to combinatorial spaces. Unlike Kool et al. (2020b), who also use a sequence of categorical distributions, the proposed method describes a differentiable reparameterization of the more complex but well-defined hypergeometric distribution. Differentiable reparameterizations of complex distributions with learnable parameters enable new applications, as shown in Section 5.

The classical use case for the hypergeometric probability distribution is sampling without replacement, for which urn models serve as the standard example. The hypergeometric distribution has previously been used as a modeling distribution in simulations of social evolution (Ono et al., 2003; Paolucci et al., 2006; Lashin et al., 2007), tracking of human neurons and gene mutations (Lodato et al., 2015; 2018), network analysis (Casiraghi et al., 2016), and recommender systems (Becchetti et al., 2011).
Further, it is used as a modeling assumption in submodular maximization (Feldman et al., 2017; Harshaw et al., 2019), multimodal VAEs (Sutter & Vogt, 2021), k-means clustering variants (Chien et al., 2018), and random permutation graphs (Bhattacharya & Mukherjee, 2017). Moreover, current sampling schemes for the multivariate hypergeometric distribution are not differentiable and trade off numerical stability against computational efficiency (Liao & Rosen, 2001; Fog, 2008a;b).

3. PRELIMINARIES

Suppose we have an urn with marbles of different colors. Let $c \in \mathbb{N}$ be the number of different classes or groups (e.g., marble colors in the urn), $m = [m_1, \ldots, m_c] \in \mathbb{N}^c$ describe the number of elements per class (e.g., marbles per color), $N = \sum_{i=1}^c m_i$ be the total number of elements (e.g., all marbles in the urn), and $n \in \{0, \ldots, N\}$ be the number of elements (e.g., marbles) to draw. Then, the multivariate hypergeometric distribution describes the probability of drawing $x = [x_1, \ldots, x_c] \in \mathbb{N}_0^c$ marbles by sampling without replacement such that $\sum_{i=1}^c x_i = n$, where $x_i$ is the number of drawn marbles of class $i$. Under the central hypergeometric distribution, every marble is picked with equal probability. The number of selected elements per class is then proportional to the ratio between the number of elements per class and the total number of elements in the urn. This assumption is often too restrictive, and we want an additional modeling parameter for the importance of a class. Generalizations that make certain classes more likely to be picked are called noncentral hypergeometric distributions. In the literature, two different versions of the noncentral hypergeometric distribution exist: Fisher's (Fisher, 1935) and Wallenius' (Wallenius, 1963; Chesson, 1976). Due to limitations of the latter (Fog, 2008a), we refer to Fisher's version of the noncentral hypergeometric distribution in the remainder of this work.

Definition 3.1 (Multivariate Fisher's Noncentral Hypergeometric Distribution (Fisher, 1935)). A random vector $X$ follows Fisher's noncentral multivariate distribution if its joint probability mass function is given by

$$P(X = x; \omega) = p_X(x; \omega) = \frac{1}{P_0} \prod_{i=1}^c \binom{m_i}{x_i} \omega_i^{x_i} \quad (1)$$

where $P_0 = \sum_{y \in S} \prod_{i=1}^c \binom{m_i}{y_i} \omega_i^{y_i}$. The support $S$ of the PMF is given by $S = \{x \in \mathbb{N}_0^c : \forall i\ x_i \leq m_i,\ \sum_{i=1}^c x_i = n\}$, and $\binom{n}{k} = \frac{n!}{k!(n-k)!}$.
The total number of samples per class $m$, the number of samples to draw $n$, and the class importance $\omega$ parameterize the multivariate distribution. Here, we assume $m$ and $n$ to be constant per experiment and are mainly interested in the class importance $\omega$. Consequently, we only use $\omega$ as the distribution parameter in Equation (1) and the remainder of this work. The class importance $\omega$ is a crucial modeling parameter in applications of the noncentral hypergeometric distribution (see Chesson, 1976). It resembles latent factors like the importance, fitness, or adaptation capabilities of a class of elements, which are often more challenging to measure in field experiments than the sizes of different populations. Introducing a differentiable and reparameterizable formulation enables learning the class importance from data (see Section 5). We provide a more detailed introduction to the hypergeometric distribution in Appendix A.
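For small urns, the PMF in Definition 3.1 can be evaluated directly by enumerating the support $S$. The following sketch (hypothetical helper name, not the paper's implementation) illustrates Equation (1):

```python
import math
from itertools import product

def fisher_nchg_pmf(x, m, omega, n):
    """PMF of the multivariate Fisher's noncentral hypergeometric distribution
    (Definition 3.1), computed by brute-force enumeration of the support S.
    Only viable for small urns; illustrative, not the paper's implementation."""
    def weight(y):
        # Unnormalized weight: prod_i C(m_i, y_i) * omega_i^{y_i}
        return math.prod(math.comb(mi, yi) * wi ** yi
                         for mi, yi, wi in zip(m, y, omega))
    # Support S: all count vectors with y_i <= m_i and sum(y) == n
    support = [y for y in product(*(range(mi + 1) for mi in m)) if sum(y) == n]
    P0 = sum(weight(y) for y in support)  # normalization constant
    return weight(tuple(x)) / P0

# Example: urn with m = (3, 2, 2) marbles, draw n = 4, class 2 twice as important.
p = fisher_nchg_pmf((2, 1, 1), m=(3, 2, 2), omega=(1.0, 2.0, 1.0), n=4)
```

With all $\omega_i$ equal, this reduces to the central hypergeometric distribution, where probabilities depend only on the binomial coefficients.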

4. METHOD

The reparameterizable sampling for the proposed differentiable hypergeometric distribution consists of three parts:

1. Reformulate the multivariate distribution as a sequence of interdependent and conditional univariate hypergeometric distributions.
2. Calculate the probability mass function of the respective univariate distributions.
3. Sample from the conditional distributions utilizing the Gumbel-Softmax trick.

We explain all steps in the following Sections 4.1 to 4.3. Additionally, Algorithm 1 and Algorithm 2 (see Appendix B.5) describe the full reparameterizable sampling method using pseudo-code, and Figures 5 and 6 in the appendix illustrate it graphically.

4.1. SEQUENTIAL SAMPLING USING CONDITIONAL DISTRIBUTIONS

Because it scales linearly with the number of classes and not with the size of the support $S$ (see Definition 3.1), we use the conditional sampling algorithm (Liao & Rosen, 2001; Fog, 2008b). Following the chain rule of probability, we sample from the following sequence of conditional probability distributions

$$p_X(x; \omega) = p_{X_1}(x_1; \omega) \prod_{i=2}^c p_{X_i}(x_i \mid \{x_j\}_{j<i}; \omega) \quad (2)$$

Following Equation (2), every $p_{X_i}(\cdot)$ describes the probability of a single class $i$ of samples given the already sampled classes $j < i$. In the conditional sampling method, we model every conditional distribution $p_{X_i}(\cdot)$ as a univariate hypergeometric distribution with two classes $L$ and $R$: for $i = 1, \ldots, c$, we define class $L := \{i\}$ as the left class and class $R := \{j : i < j \leq c\}$ as the right class.

Algorithm 1: Sampling from the differentiable hypergeometric distribution. The different blocks are explained in more detail in Sections 4.1 to 4.3 and Algorithm 2.

Input: $m \in \mathbb{N}^c$, $\omega \in \mathbb{R}^c_{0+}$, $n \in \mathbb{N}$, $\tau \in \mathbb{R}_{0+}$
Output: $x \in \mathbb{N}^c_0$, $\{\alpha_i \in \mathbb{R}^{m_i}\}_{i=1}^c$, $\{\tilde{r}_i \in \mathbb{R}^{m_i}\}_{i=1}^c$

for $i \in \{1, \ldots, c\}$ do
    $L \leftarrow \{i\}$, $R \leftarrow \{i+1, \ldots, c\}$  # Formulate the multivariate as a univariate
    $m \rightarrow m_L, m_R \in \mathbb{Z}_{0+}$, $\omega \rightarrow \omega_L, \omega_R \in \mathbb{R}_{0+}$  # distribution (Section 4.1)
    $x_L, \alpha_L, \tilde{r}_L \leftarrow$ sampleUNCHG($m_L, m_R, \omega_L, \omega_R, n, \tau$)  # Sample from univariate distribution
    $n \leftarrow n - x_L$, $m \leftarrow m \setminus m_L$, $\omega \leftarrow \omega \setminus \omega_L$  # Re-assign classes for next step
    $x_i \leftarrow x_L$, $\alpha_i \leftarrow \alpha_L$, $\tilde{r}_i \leftarrow \tilde{r}_L$  # Assign values for class $i$
end for
return $x$, $\{\alpha_i\}_{i=1}^c$, $\{\tilde{r}_i\}_{i=1}^c$

function sampleUNCHG($m_i$, $m_j$, $\omega_i$, $\omega_j$, $n$, $\tau$)
    $\alpha_i \leftarrow$ calcLogPMF($m_i, m_j, \omega_i, \omega_j, n$)  # Section 4.2
    $x_i, \tilde{r}_i \leftarrow$ contRelaxSample($\alpha_i, \tau$)  # Section 4.3
    return $x_i, \alpha_i, \tilde{r}_i$
end function

To sample from the original multivariate hypergeometric distribution, we sequentially sample from the urn with only two classes $L$ and $R$, which simplifies to sampling from the univariate noncentral hypergeometric distribution given by the following parameters (Fog, 2008b):

$$m_L = \sum_{l \in L} m_l, \quad m_R = \sum_{r \in R} m_r, \quad \omega_L = \sum_{l \in L} \omega_l \cdot \frac{m_l}{m_L}, \quad \omega_R = \sum_{r \in R} \omega_r \cdot \frac{m_r}{m_R} \quad (3)$$

We leave the exploration of different and more sophisticated subset selection strategies for future work. Samples drawn using this algorithm are only approximately equal to samples from the joint noncentral multivariate distribution with equal $\omega$. Because of the merging operation in Equation (3), the approximation error is zero only for the central hypergeometric distribution. One way to reduce this approximation error independent of the learned $\omega$ is a different subset selection algorithm (Fog, 2008b). Note that the proposed method introduces an approximation error compared to a non-differentiable reference implementation with the same $\omega$ (see Section 5.1), but not with respect to the underlying true class importance: a different learned $\omega$ can compensate for the approximation error introduced by the merging needed for the conditional sampling procedure.
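The merging step of Equation (3) is a weighted aggregation of the remaining classes. A minimal sketch (0-based indices, hypothetical helper name):

```python
def merge_classes(m, omega, i):
    """Collapse the classes not yet sampled at step i into a left class
    L = {i} and a right class R = {i+1, ..., c}, following Equation (3).
    m and omega are lists over all classes; i is 0-based here."""
    m_L = m[i]
    m_R = sum(m[i + 1:])
    # L contains a single class, so its weighted mean importance is omega[i].
    w_L = omega[i]
    # omega_R is the m-weighted mean importance of the merged right classes.
    w_R = sum(w * mj for w, mj in zip(omega[i + 1:], m[i + 1:])) / m_R
    return m_L, m_R, w_L, w_R
```

This weighted mean is the source of the approximation error discussed above: it is exact only when all merged classes share the same importance.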

4.2. CALCULATE PROBABILITY MASS FUNCTION

In Section 4.1, we derived a sequential sampling procedure for the hypergeometric distribution, in which we repeatedly draw from a univariate distribution to simplify sampling. Therefore, we only need to compute the PMF for the univariate distribution (see Appendix B.1 for the multivariate extension). The PMF $p_{X_L}(x_L; \omega)$ for the hypergeometric distribution of two classes $L$ and $R$ defined by $m_L$, $m_R$, $\omega_L$, $\omega_R$, and $n$ is given as

$$p_{X_L}(x_L; \omega) = \frac{1}{P_0} \binom{m_L}{x_L} \omega_L^{x_L} \binom{m_R}{n - x_L} \omega_R^{n - x_L} \quad (4)$$

where $P_0$ is defined as in Equation (1), and $\omega_L$, $\omega_R$ (with their derivation from $\omega$), $m_L$, $m_R$, and $n$ as in Equation (3). The large exponents of $\omega$ and the combinatorial terms can lead to numerical instabilities, making the direct calculation of the PMF in Equation (4) infeasible. Calculating in the log-domain increases numerical stability over such large ranges while keeping the relative ordering.

Lemma 4.1. The unnormalized log-probabilities

$$\log p_{X_L}(x_L; \omega) = x_L \log \omega_L + (n - x_L) \log \omega_R + \psi_F(x_L) + C \quad (5)$$

define the unnormalized weights of a categorical distribution that follows Fisher's noncentral hypergeometric distribution. $C$ is a constant and $\psi_F(x_L)$ is defined as

$$\psi_F(x_L) = -\log\left(\Gamma(x_L + 1)\Gamma(n - x_L + 1)\right) - \log\left(\Gamma(m_L - x_L + 1)\Gamma(m_R - n + x_L + 1)\right) \quad (6)$$

We provide the proof of Lemma 4.1 in Appendix B.2. Common automatic differentiation frameworks have numerically stable implementations of $\log \Gamma(\cdot)$. Therefore, Equation (6), and more importantly Equation (5), can be calculated efficiently and reliably. Lemma 4.1 relates to the calcLogPMF function in Algorithm 1, and Algorithm 2 describes calcLogPMF in more detail. Using the multivariate form of Lemma 4.1 (see Appendix B.1), it is possible to directly calculate the categorical weights for the full support $S$. Compared to the conditional sampling procedure, this would result in a computational speed-up for a large number of classes $c$. However, the size of the support $S$ grows with $\prod_{i=1}^c m_i$, quickly resulting in infeasible memory requirements, which would restrict us to settings with no practical relevance.
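Lemma 4.1 translates directly into code via the log-gamma function. A minimal sketch using Python's `math.lgamma` (hypothetical helper; an autodiff framework would use the analogous tensor op, e.g. `torch.lgamma`):

```python
import math

def log_pmf_unnormalized(m_L, m_R, w_L, w_R, n):
    """Unnormalized log-probabilities of Equation (5) for x_L = 0, ..., m_L,
    using lgamma for numerical stability. Values of x_L outside the support
    (x_L > n or n - x_L > m_R) receive weight -inf."""
    logits = []
    for x_L in range(m_L + 1):
        if x_L > n or n - x_L > m_R:
            logits.append(float("-inf"))
            continue
        # psi_F from Equation (6): the log-binomial terms minus constants
        psi = (-(math.lgamma(x_L + 1) + math.lgamma(n - x_L + 1))
               - (math.lgamma(m_L - x_L + 1) + math.lgamma(m_R - n + x_L + 1)))
        logits.append(x_L * math.log(w_L) + (n - x_L) * math.log(w_R) + psi)
    return logits
```

The constant $C$ is dropped, which is harmless because the softmax in the next step is shift-invariant.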

4.3. CONTINUOUS RELAXATION FOR THE CONDITIONAL DISTRIBUTION

Continuous relaxations describe procedures to make discrete distributions differentiable with respect to their parameters (Huijben et al., 2021). Following Lemma 4.1, we make use of the Gumbel-Softmax trick to reparameterize the hypergeometric distribution via its conditional distributions $p_{X_L}(\cdot)$. The Gumbel-Softmax trick enables a reparameterization of categorical distributions that allows the computation of gradients with respect to the distribution parameters. We state Lemma 4.2 and provide a proof in Appendix B.3.

Lemma 4.2. The Gumbel-Softmax trick can be applied to the conditional distribution $p_{X_i}(x_i \mid \{x_k\}_{k=1}^{i-1}; \omega)$ of class $i$ given the already sampled classes $k < i$.

Lemma 4.2 connects the Gumbel-Softmax trick to the hypergeometric distribution. Hence, reparameterizing enables gradients with respect to the parameter $\omega$ of the hypergeometric distribution:

$$u \sim U(0, 1), \quad g_i = -\log(-\log(u)), \quad \tilde{r}_i = \alpha_i(\omega) + g_i \quad (7)$$

where $u \in [0, 1]^{m_i + 1}$ is a random sample from an i.i.d. uniform distribution $U$, and $g_i$ is therefore i.i.d. Gumbel noise. $\tilde{r}_i$ are the perturbed conditional probabilities for class $i$ given the class-conditional unnormalized log-weights $\alpha_i(\omega)$:

$$\alpha_i(\omega) = \log p_{X_i}(x_i; \omega) - C = [\log p_{X_i}(0; \omega), \ldots, \log p_{X_i}(m_i; \omega)] - C \quad (8)$$

We use the softmax function to generate $(m_i + 1)$-dimensional sample vectors from the perturbed unnormalized weights $\tilde{r}_i / \tau$, where $\tau$ is the temperature parameter. Due to Lemma 4.2, we do not need to calculate the constant $C$ in Equations (5) and (8). We refer to the original works (Jang et al., 2016; Maddison et al., 2017) or Appendix A.2 for more details on the Gumbel-Softmax trick itself. Lemma 4.2 corresponds to the contRelaxSample function in Algorithm 1 (see Algorithm 2 for more details).

Note the difference between the categorical and the hypergeometric distribution with respect to the Gumbel-Softmax trick. Whereas for the former the (unnormalized) category weights are also the distribution parameters, for the latter the log-weights $\alpha_i$ of class $i$ are a function of the class importance $\omega$ and the pre-defined support $x_i = [0, \ldots, m_i]$. It follows that for a sequence of categorical distributions, we would have $\sum_{i=1}^c m_i$ learnable parameters, whereas for the differentiable hypergeometric distribution we only have $c$ learnable parameters. In Appendix B.4, we discuss further difficulties in using a sequence of unconstrained differentiable categorical distributions.
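The perturbation and tempered softmax of Equation (7) can be sketched as follows (hypothetical helper; a real implementation would operate on framework tensors so that gradients flow back into $\omega$ through the logits):

```python
import math
import random

def gumbel_softmax_sample(logits, tau=0.5):
    """Draw one relaxed sample over the support {0, ..., m_L} following
    Equation (7): perturb the unnormalized log-weights with Gumbel noise,
    then apply a tempered softmax. Entries with logit -inf keep zero mass."""
    perturbed = []
    for logit in logits:
        if logit == float("-inf"):
            perturbed.append(float("-inf"))
        else:
            u = random.uniform(1e-12, 1.0)  # avoid log(0)
            perturbed.append(logit - math.log(-math.log(u)))  # logit + Gumbel noise
    mx = max(perturbed)  # subtract max for numerical stability
    exps = [0.0 if p == float("-inf") else math.exp((p - mx) / tau) for p in perturbed]
    total = sum(exps)
    soft = [e / total for e in exps]
    # The relaxed count: inner product of the soft one-hot vector with {0,...,m_L}
    x_relaxed = sum(k * s for k, s in enumerate(soft))
    return soft, x_relaxed
```

As $\tau \to 0$ the soft vector concentrates on a single entry and `x_relaxed` approaches a discrete draw.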

5. EXPERIMENTS

We perform three experiments that empirically validate the proposed method and highlight the versatility and applicability of the differentiable hypergeometric distribution in different important areas of machine learning. We first test the samples generated by the proposed differentiable formulation against a non-differentiable reference implementation. Second, we show how the hypergeometric distribution helps detect shared generative factors of paired samples in a weakly-supervised setting. Our third experiment demonstrates the hypergeometric distribution as a prior in variational clustering algorithms. In our experiments, we focus on applications in the field of distribution inference, where we make use of the reparameterizable gradients. Nevertheless, the proposed method is applicable to any gradient-based optimization in which the underlying process models sampling without replacement.

5.1. KOLMOGOROV-SMIRNOV TEST

To assess the accuracy of the proposed method, we evaluate it against a reference implementation using the Kolmogorov-Smirnov (KS) test (Kolmogorov, 1933; Smirnov, 1939). It is a nonparametric test that estimates the equality of two distributions by quantifying the distance between the empirical distributions of their samples. The null distribution of this test is calculated under the null hypothesis that the two groups of samples are drawn from the same distribution. If the test fails to reject the null hypothesis, there is no evidence that the two groups of samples were generated by different distributions. As described in Section 4, we use class-conditional distributions to sample from the differentiable hypergeometric distribution. We compare samples from our differentiable formulation to samples from a non-differentiable reference implementation (Virtanen et al., 2020, SciPy). For this experiment, we use a multivariate hypergeometric distribution with three classes and perform a sensitivity analysis with respect to the class weights $\omega$. We keep $\omega_1$ and $\omega_3$ fixed at 1.0, and increase $\omega_2$ from 1.0 to 10.0 in steps of 1.0. For every value of $\omega_2$, we sample 50,000 i.i.d. random vectors. We use the Benjamini-Hochberg correction (Benjamini & Hochberg, 1995) to adjust the p-values for the false discovery rate of multiple comparisons, as we perform $c = 3$ tests per joint distribution. Given a significance threshold of $t = 0.05$, $p > 0.05$ implies that we cannot reject the null hypothesis, which is desirable for our application. Figure 1a shows the histogram of class 2 samples for all values of $\omega_2$, and Figure 1b the results of the KS test for all classes. The histograms for classes 1 and 3 are in the Appendix (Figure 4). We see that the calculated distances of the KS test are small and the corrected p-values well above the threshold; many are even close to 1.0. Hence, the test clearly fails to reject the null hypothesis in 30 out of 30 cases.
Additionally, the proposed and the reference implementation histograms are visually similar. The results of the KS test strongly imply that the proposed differentiable formulation effectively follows a noncentral hypergeometric distribution. We provide more analyses and results from KS test experiments in Appendix C.1.
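For intuition, the two-sample KS statistic is simply the largest gap between the two empirical CDFs. A minimal pure-Python version for discrete samples (the experiments rely on the SciPy implementation; this $O(n^2)$ helper is illustrative only):

```python
def ks_statistic(samples_a, samples_b):
    """Two-sample Kolmogorov-Smirnov statistic: the maximum absolute distance
    between the empirical CDFs of the two sample groups."""
    support = sorted(set(samples_a) | set(samples_b))
    n_a, n_b = len(samples_a), len(samples_b)
    return max(
        abs(sum(v <= x for v in samples_a) / n_a
            - sum(v <= x for v in samples_b) / n_b)
        for x in support
    )
```

Small values of this statistic (and large corrected p-values) are what the table in Figure 1b reports for the proposed sampler versus the SciPy reference.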

5.2. WEAKLY-SUPERVISED LEARNING

Many data modalities, such as consecutive frames in a video, are not observed as i.i.d. samples, which provides a weak-supervision signal for representation learning and generative models. Hence, we are not only interested in learning meaningful representations and approximating the data distribution but also in the detailed relation between frames. Assuming underlying factors generate such coupled data, a subset of factors should be shared among frames to describe the underlying concept leading to the coupling. Consequently, differing concepts between coupled frames stem from independent factors exclusive to each frame. The differentiable hypergeometric distribution provides a principled approach to learning the number of shared and independent factors in an end-to-end fashion.

Table 1: To compare the three methods (LabelVAE, AdaVAE, HGVAE) in the weakly-supervised experiment, we evaluate their learned latent representations with respect to shared (S) and independent (I) generative factors. To assess the amount of shared and independent information in the latent representation, we train linear classifiers on the respective latent dimensions only. We report the adjusted balanced classification accuracy, such that the random classifier achieves a score of 0.

In this experiment, we look at pairs of images from the synthetic mpi3D toy dataset (Gondal et al., 2019). We generate a coupled dataset by pairing images that share a certain number of their seven generative factors. We train all models as variational autoencoders (Kingma & Welling, 2014, VAE) to maximize an evidence lower bound (ELBO) on the marginal log-likelihood of images using the setting and code from Locatello et al. (2020). We compare three methods, which sequentially encode both images into latent representations using a single encoder. Based on the two representations, the subset of shared latent dimensions is aggregated into a single representation.
Finally, a single decoder independently computes reconstructions for both samples based on the imputed shared and the remaining independent latent factors provided by the respective encoder. The methods only differ in how they infer the subset of shared latent factors. We refer to Appendix C.2 or Locatello et al. (2020) for more details on the setup of the weakly-supervised experiment. LabelVAE assumes that the number of independent factors is known (Bouchacourt et al., 2018; Hosoya, 2018). As in Locatello et al. (2020), we assume 1 known independent factor for all experiments. AdaVAE relies on a heuristic to infer the shared factors (Locatello et al., 2020): based on the Kullback-Leibler (KL) divergence between corresponding latent factors across images, the decision between shared and independent factors is made using a hand-designed threshold. The proposed HGVAE uses the differentiable hypergeometric distribution to model the number of shared and independent latent factors. We infer the unknown group importance $\omega \in \mathbb{R}^2_{0+}$ with a single dense layer, which uses the KL divergences between corresponding latent factors as input (similar to AdaVAE). Based on $\omega$ and with $d$ being the number of latent dimensions, we sample from the reparameterized hypergeometric distribution to estimate the number of independent factors $\hat{k}$ and shared factors $\hat{s} := d - \hat{k}$. The proposed differentiable formulation allows us to infer $\omega$ and simultaneously learn the latent representation in a fully differentiable setting. After sorting the latent factors by KL divergence (Grover et al., 2019), we define the top-$\hat{k}$ latent factors as independent and the remaining $\hat{s}$ as shared. For a more detailed description of HGVAE, the baseline models, and the dataset, see Appendix C.2. To evaluate the methods, we compare their performance on two different tasks.
We measure the mean squared error MSE$(s, \hat{s})$ between the actual number of shared latent factors $s$ and the estimated number $\hat{s}$ (Figure 2), and the classification accuracy of predicting the generative factors on the shared and independent subsets of the learned representations (Table 1). We train classifiers for all factors on the shared and independent parts of the latent representation and calculate their balanced accuracy. The reported scores are the average over the factor-specific balanced accuracies. Because the number of classes differs between the discrete generative factors, we report the adjusted balanced accuracy as the classification metric. These two tasks challenge the methods regarding their estimate of the relationship between images. The dataset generation controls the number of independent and shared factors, $k$ and $s$, which allows us to evaluate the methods on different versions of the same underlying data with respect to the number of shared and independent generative factors. We generate four weakly-supervised datasets with $s \in \{0, 1, 3, 5\}$. On purpose, we also evaluate the edge case $s = 0$, in which the two views do not share any generative factors. Figure 2 shows that previous methods cannot accurately estimate the number of shared factors. Both baseline methods estimate the same number of shared factors $\hat{s}$ independent of the underlying ground-truth number of shared factors $s$. While this is not surprising for LabelVAE, it is unexpected for AdaVAE, given its, in theory, adaptive heuristic. On the other hand, the low mean squared error (MSE) reflects that the proposed HGVAE can dynamically estimate the number of shared factors for every value of $s$. These results suggest that the differentiable hypergeometric distribution is able to learn the relative importance of shared and independent factors in the latent representations.
The assumptions of previous works, though, do not reflect the data's generative process (LabelVAE), or the designed heuristics are oversimplified (AdaVAE). We also see the effect of these oversimplified assumptions when evaluating the learned latent representation. Table 1 shows that estimating too large a number of shared factors $\hat{s}$ leads to an inferior latent representation of the independent factors, which is reflected in the lower accuracy scores of previous work compared to the proposed method. More surprisingly, our HGVAE reaches the same performance on the shared latent representation despite being more flexible. Given the general nature of the method, these positive results are encouraging. Unlike previous works, HGVAE is not explicitly designed for weakly-supervised learning but achieves results comparable or superior to field-specific models. Additionally, it accurately estimates the latent space structure across different experimental settings.
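The partitioning step of HGVAE can be sketched with a hard argsort (the model itself uses the differentiable sorting relaxation of Grover et al. (2019)); function name and interface here are hypothetical:

```python
def split_latent_dims(kl_per_dim, k_hat):
    """Sketch of HGVAE's partitioning step: the k_hat latent dimensions with
    the largest KL divergence between the two views are treated as independent;
    the remaining dimensions are shared. Hard argsort stands in for the
    differentiable sorting relaxation used in the actual model."""
    order = sorted(range(len(kl_per_dim)), key=lambda i: -kl_per_dim[i])
    independent = sorted(order[:k_hat])  # top-k_hat by KL divergence
    shared = sorted(order[k_hat:])       # the remaining s_hat = d - k_hat dims
    return independent, shared
```

In the full model, $\hat{k}$ itself is the sampled output of the differentiable hypergeometric distribution, so gradients reach both the partition and the group importance $\omega$.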

5.3. DEEP VARIATIONAL CLUSTERING

We investigate the use of the differentiable hypergeometric distribution in a deep clustering task. Several techniques have been proposed in the literature to combine long-established clustering algorithms, such as K-means or Gaussian mixture models, with the flexibility of deep neural networks to learn better representations of high-dimensional data (Min et al., 2018). Among those, Jiang et al. (2016), Dilokthanakul et al. (2016), and Manduchi et al. (2021) include a trainable Gaussian mixture prior distribution in the latent space of a VAE. A Gaussian mixture model permits a probabilistic approach to clustering where an explicit generative assumption of the data is defined. All methods are optimized using stochastic gradient variational Bayes (Kingma & Welling, 2014; Rezende et al., 2014). A major drawback of the above models is that the samples are either assumed to be i.i.d. or they require pairwise side information, which limits their applicability in real-world scenarios. The differentiable hypergeometric distribution can easily be integrated into VAE-based clustering algorithms to overcome these limitations. We want to cluster a given dataset $X = \{x_i\}_{i=1}^N$ into $K$ subgroups. Like previous work (Jiang et al., 2016), we assume the data is generated by a random process in which the cluster assignments are first drawn from a prior probability $p(c; \pi)$, then each latent embedding $z_i$ is sampled from a Gaussian distribution whose mean and variance depend on the selected cluster $c_i$. Finally, the sample $x_i$ is generated from a Bernoulli distribution whose parameter $\mu_{x_i}$ is the output of a neural network parameterized by $\theta$, as in the classical VAE. With these assumptions, the latent embeddings $z_i$ follow a mixture of Gaussian distributions, whose means and variances $\{\mu_i, \sigma_i^2\}_{i=1}^K$ are tunable parameters.
The above generative model can then be optimised by maximising the ELBO using the stochastic gradient variational Bayes estimator (we refer to Appendix C.3 for a complete description of the optimisation procedure). Previous work (Jiang et al., 2016) modeled the prior distribution as $p(c; \pi) = \prod_i p(c_i) = \prod_i \text{Cat}(c_i \mid \pi)$ with either tunable or fixed parameters $\pi$. In this task, we instead replace this prior with the multivariate noncentral hypergeometric distribution with weights $\pi$ and $K$ classes, where every class relates to a cluster. Hence, we sample the number of elements per cluster (or cluster size) following Definition 3.1 and Algorithm 1. The hypergeometric distribution creates a dependence between samples: the prior probability of a sample being assigned to a given cluster is no longer independent of the remaining samples, allowing us to loosen the over-restrictive i.i.d. assumption.

Table 2: Evaluation of the clustering experiment on the MNIST datasets. We compare the methods on 3 different dataset versions, namely i) uniform class distribution, ii) subsampling with 80% of samples, and iii) subsampling with only 60% of samples. We subsample half of the classes. Accuracy (Acc), normalized mutual information (NMI), and adjusted rand index (ARI) are used as evaluation metrics. Higher is better for all metrics. Mean and standard deviations are computed across 5 runs. For fair comparison with the baselines, all methods use the pretraining weights provided by Jiang et al. (2016).

We explore the effect of three different prior probabilities in Equation (54), namely (i) the categorical distribution, by setting $p(c; \pi) = \prod_i \text{Cat}(c_i \mid \pi)$; (ii) the uniform distribution, by fixing $\pi_i = 1/K\ \forall i \in \{1, \ldots, K\}$; and (iii) the multivariate noncentral hypergeometric distribution. We compare them on three different MNIST versions (LeCun & Cortes, 2010). The first version is the standard dataset, which has a balanced class distribution.
For the second and third dataset versions, we explore different subsampling ratios for half of the classes. The subsampling rates are 80% in the moderate and 60% in the severe case. In Table 2, we evaluate the methods with respect to their clustering accuracy (Acc), normalized mutual information (NMI), and adjusted rand index (ARI). The hypergeometric prior distribution shows good clustering performance on all datasets. Although the uniform distribution performs reasonably well, it assumes the clusters have equal importance; hence, it might fail in more complex scenarios. On the other hand, the categorical distribution has subpar performance compared to the uniform distribution, even in the moderate subsampling setting. This might be due to the additional complexity introduced by the learnable cluster weights, which leads to unstable results. On the contrary, this additional complexity does not seem to affect the performance of the proposed hypergeometric prior but instead boosts its clustering performance, especially on the imbalanced datasets. In Figure 3, we show that the model is able to learn weights π that reflect the subsampling rates of each cluster, which is not directly possible using the uniform prior model.

6. CONCLUSION

We propose the differentiable hypergeometric distribution in this work. In combination with the Gumbel-Softmax trick, this new formulation enables reparameterized gradients with respect to the class weights ω of the hypergeometric distribution. We show the versatility of the hypergeometric distribution in machine learning by applying it to two common areas: clustering and weakly-supervised learning. In both applications, methods using the hypergeometric distribution reach at least state-of-the-art performance. We believe this work is an essential step toward integrating the hypergeometric distribution into more machine learning algorithms. Applications in biology and the social sciences represent potential directions for future work.

ACKNOWLEDGMENTS

Health and Related Technologies (PHRT)" of the ETH Domain (Swiss Federal Institutes of Technology). LM is supported by the SDSC PhD Fellowship #1-001568-037. AR is supported by the StimuLoop grant #1-007811-002 and the Vontobel Foundation.

7. ETHICS STATEMENT

In this work, we propose a general approach to learning the importance of subgroups. In that regard, potential ethical concerns arise with the different applications our method could be applied to. We intend to apply our model in the medical domain in future work. Being able to correctly model the dependencies of groups of patients is important and offers the potential of correctly identifying underlying causes of diseases on a group level. On the other hand, grouping patients needs to be handled carefully and further research is needed to ensure fairness and reliability with respect to underlying and hidden attributes of different groups.

8. REPRODUCIBILITY STATEMENT

For all theoretical statements, we provide detailed derivations and state the necessary assumptions. We present empirical support on both synthetic and real data to back our idea of introducing the differentiable hypergeometric distribution. To ensure empirical reproducibility, the results of each experiment and every ablation were averaged over multiple seeds and are reported with standard deviations. All of the used datasets are public or can be generated from publicly available resources using the code that we provide in the supplementary material. Information about implementation details, hyperparameter settings, and evaluation metrics is provided in the main text or the appendix.

A PRELIMINARIES

A.1 HYPERGEOMETRIC DISTRIBUTION

The hypergeometric distribution is a discrete probability distribution that describes the probability of x successes in n draws without replacement from a finite population of size N with m elements belonging to the success class. In contrast, the binomial distribution describes the probability of x successes in n draws with replacement.

Definition A.1 (Hypergeometric Distribution (Gonin, 1936)). A random variable X follows the hypergeometric distribution if its probability mass function (PMF) is given by

P(X = x) = p_X(x) = \binom{m}{x} \binom{N−m}{n−x} / \binom{N}{n}

Published as a conference paper at ICLR 2023

Urn models are typical examples of the hypergeometric probability distribution. Suppose we think of an urn with marbles in two different colors, e.g. green and purple; we can label the drawing of a green marble as a success. Then N defines the total number of marbles and m the number of green marbles in the urn; x is the number of drawn green marbles, and n − x is the number of drawn purple marbles. The multivariate hypergeometric distribution describes an urn with more than two colors, e.g. green, purple, and yellow in the simplest case with three colors. As described in Johnson (1987), the definition is given by:

Definition A.2 (Multivariate Hypergeometric Distribution). A random vector X follows the multivariate hypergeometric distribution if its joint probability mass function is given by

P(X = x) = p_X(x) = ( \prod_{i=1}^{c} \binom{m_i}{x_i} ) / \binom{N}{n}   (10)

where c ∈ N_+ is the number of different classes (e.g. marble colors in the urn), m = [m_1, ..., m_c] ∈ N^c describes the number of elements per class (e.g. marbles per color), N = Σ_{i=1}^c m_i is the total number of elements (e.g. all marbles in the urn), and n ∈ {0, ..., N} is the number of elements (e.g. marbles) to draw.
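Definitions A.1 and A.2 can be checked numerically; a minimal sketch using Python's `math.comb` (the function names are ours):

```python
import math

def hypergeom_pmf(x, N, m, n):
    """P(X = x): x successes in n draws without replacement from a
    population of size N containing m success elements (Definition A.1)."""
    return math.comb(m, x) * math.comb(N - m, n - x) / math.comb(N, n)

def mv_hypergeom_pmf(xs, ms, n):
    """Joint PMF of the multivariate hypergeometric distribution (Definition A.2)."""
    N = sum(ms)
    num = 1
    for x, m in zip(xs, ms):
        num *= math.comb(m, x)
    return num / math.comb(N, n)

# urn with 5 green (success) and 7 purple marbles, drawing n = 4
p2 = hypergeom_pmf(2, 12, 5, 4)
```

With two classes, the multivariate PMF reduces to the univariate one, and the PMF sums to one over its support.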
The support S of the PMF is given by

S = { x ∈ N_0^c : x_i ≤ m_i ∀ i, Σ_{i=1}^c x_i = n }

A.2 GUMBEL-SOFTMAX TRICK

Most of the information in this section is from Jang et al. (2016) and Maddison et al. (2017), which concurrently introduced the Gumbel-Softmax trick. The Gumbel-Softmax is a continuous distribution that can approximate the categorical distribution and whose parameter gradients can easily be computed using the reparameterization trick. Let z be a categorical variable with categorical weights π = [π_1, ..., π_C] such that Σ_{k=1}^C π_k = 1. Following Jang et al. (2016), we assume that categorical samples are encoded as one-hot vectors. The Gumbel-Max trick (Gumbel, 1954; Maddison et al., 2014) defines an efficient way to draw samples z from a categorical distribution with weights π:

z = one_hot( argmax_k [ log π_k + g_k ] )

where g = [g_1, ..., g_C] are i.i.d. samples drawn from Gumbel(0, 1). We can efficiently sample g from Gumbel(0, 1) by drawing a sample u from a uniform distribution U(0, 1) and applying the transform g = −log(−log u). For more details, we refer to Gumbel (1954) and Maddison et al. (2014). Jang et al. (2016) and Maddison et al. (2017) both use the softmax function as a continuous and differentiable approximation to argmax. The softmax function is defined as

p_k = exp((log π_k + g_k)/τ) / Σ_{j=1}^C exp((log π_j + g_j)/τ)   for k = 1, ..., C

where τ is a temperature parameter. As τ goes to zero, the softmax function approximates the argmax function; hence, the Gumbel-Softmax distribution approximates the categorical distribution.
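A minimal stdlib sketch of the Gumbel-Max and Gumbel-Softmax samplers described above (function names are ours; the clamping of u only guards against log(0)):

```python
import math
import random

def gumbel():
    """Sample g ~ Gumbel(0, 1) via g = -log(-log u), u ~ U(0, 1)."""
    u = min(max(random.random(), 1e-12), 1.0 - 1e-12)
    return -math.log(-math.log(u))

def gumbel_max(log_pi):
    """Gumbel-Max trick: argmax_k (log pi_k + g_k) is a categorical sample."""
    return max(range(len(log_pi)), key=lambda k: log_pi[k] + gumbel())

def gumbel_softmax(log_pi, tau):
    """Gumbel-Softmax: softmax((log pi + g) / tau), a relaxed one-hot vector."""
    y = [(lp + gumbel()) / tau for lp in log_pi]
    mx = max(y)  # subtract the max for numerical stability
    e = [math.exp(v - mx) for v in y]
    s = sum(e)
    return [v / s for v in e]

random.seed(0)
pi = [0.2, 0.5, 0.3]
log_pi = [math.log(p) for p in pi]
freq = [0, 0, 0]
for _ in range(20000):
    freq[gumbel_max(log_pi)] += 1
freq = [f / 20000 for f in freq]
soft = gumbel_softmax(log_pi, tau=0.5)
```

The empirical Gumbel-Max frequencies approach π, and the Gumbel-Softmax output is a valid probability vector.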

B.1 PMF FOR THE MULTIVARIATE FISHER'S NONCENTRAL HYPERGEOMETRIC DISTRIBUTION

In this section, we give a detailed derivation for the calculation of the log-probabilities of the multivariate Fisher's noncentral hypergeometric distribution. We end up with a formulation that is proportional to the actual log-probabilities. Because the ordering of categories is not influenced by scaling with a constant factor (addition/subtraction in the log domain), these are unnormalized log-probabilities of the multivariate Fisher's noncentral hypergeometric distribution. The PMF is given by

p_X(x; ω) = (1/P_0) \prod_{i=1}^{c} \binom{m_i}{x_i} ω_i^{x_i}

where P_0 is defined as in Equation (1). From there it follows

log p_X(x; ω) = log [ (1/P_0) \prod_{i=1}^{c} \binom{m_i}{x_i} ω_i^{x_i} ]   (15)
= log(1/P_0) + log \prod_{i=1}^{c} \binom{m_i}{x_i} ω_i^{x_i}   (16)
= log(1/P_0) + Σ_{i=1}^{c} log [ \binom{m_i}{x_i} ω_i^{x_i} ]   (17)
= log(1/P_0) + Σ_{i=1}^{c} [ log \binom{m_i}{x_i} + log(ω_i^{x_i}) ]   (18)
= log(1/P_0) + Σ_{i=1}^{c} [ log \binom{m_i}{x_i} + x_i log ω_i ]   (19)

The constant factor can be removed, as the argmax is invariant to scaling with a constant factor, which equals addition or subtraction in log-space. Absorbing log(1/P_0), and subsequently the constant terms Σ_i log(m_i!) of the binomial coefficients, into constants C and C′, it follows

log p_X(x; ω) = Σ_{i=1}^{c} [ log \binom{m_i}{x_i} + x_i log ω_i ] + C   (20)
= Σ_{i=1}^{c} [ log ( 1 / (x_i! (m_i − x_i)!) ) + x_i log ω_i ] + C′   (21)
= Σ_{i=1}^{c} [ −log ( Γ(x_i + 1) Γ(m_i − x_i + 1) ) + x_i log ω_i ] + C′   (22)

In the last line, we used the relation Γ(k + 1) = k!. Setting C := C′, it directly follows

log p_X(x; ω) = Σ_{i=1}^{c} x_i log ω_i + ψ_F(x) + C

where ψ_F(x) = −Σ_{i=1}^{c} log ( Γ(x_i + 1) Γ(m_i − x_i + 1) ). The Gamma function is defined in Whittaker & Watson (1996) as

Γ(z) = ∫_0^∞ x^{z−1} e^{−x} dx

B.2 PROOF FOR LEMMA 4.1

Proof. Factors that are constant for all x do not change the relative ordering between different values of x. Hence, removing them preserves the ordering of values x (Barrett, 2017).
Using the definition of the binomial coefficient (see Section 3) and its relation to the Gamma function, Γ(k + 1) = k!, it follows

log p_{X_L}(x_L; ω) = log [ (1/P_0) \binom{m_L}{x_L} ω_L^{x_L} \binom{m_R}{n − x_L} ω_R^{n − x_L} ]   (26)
= log \binom{m_L}{x_L} + log \binom{m_R}{n − x_L} + log(ω_L^{x_L}) + log(ω_R^{n − x_L}) + C
= x_L · log ω_L + (n − x_L) · log ω_R   (28)
  − log ( Γ(x_L + 1) Γ(n − x_L + 1) )
  − log ( Γ(m_L − x_L + 1) Γ(m_R − n + x_L + 1) ) + C

With ψ_F(x) as defined in Equation (6), Equation (5) follows directly.
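The claim that the dropped constants only shift the log-probabilities can be checked numerically; a sketch comparing the lgamma-based unnormalized form with the explicit binomial-coefficient form (helper names are ours):

```python
import math

def unnorm_log_pmf(xs, ms, ws):
    """Unnormalized log-PMF from the derivation above:
    sum_i [ x_i log w_i - log(Gamma(x_i + 1) * Gamma(m_i - x_i + 1)) ]."""
    t = 0.0
    for x, m, w in zip(xs, ms, ws):
        t += x * math.log(w)
        t -= math.lgamma(x + 1) + math.lgamma(m - x + 1)
    return t

def binom_log_pmf_terms(xs, ms, ws):
    """Reference form: sum_i [ log C(m_i, x_i) + x_i log w_i ]
    (still without the log(1/P_0) normalizer)."""
    return sum(math.log(math.comb(m, x)) + x * math.log(w)
               for x, m, w in zip(xs, ms, ws))

ms, ws = [10, 7, 8], [1.0, 2.0, 0.5]
x_a, x_b = [2, 4, 3], [5, 1, 3]
diff_a = binom_log_pmf_terms(x_a, ms, ws) - unnorm_log_pmf(x_a, ms, ws)
diff_b = binom_log_pmf_terms(x_b, ms, ws) - unnorm_log_pmf(x_b, ms, ws)
```

The difference between the two forms is the same constant Σ_i log(m_i!) for every x, so the ordering of values is preserved.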

B.3 PROOF FOR LEMMA 4.2

Proof. When sampling class i, we draw x_i samples from class i, where x_i ≤ m_i. The conditional distribution p_{X_i}(x_i | {x_k}_{k=1}^{i−1}; ω) for class i, given the already sampled classes k < i, simultaneously defines the weights of a categorical distribution. Sampling x_i elements from class i can be seen as selecting the x_i-th category from the distribution defined by the weights p_{X_i}(x_i | {x_k}_{k=1}^{i−1}; ω). Therefore, Σ_{0 ≤ x_i ≤ m_i} p_{X_i}(x_i | {x_k}_{k=1}^{i−1}; ω) = 1, which allows us to apply the Gumbel-Max and, respectively, the Gumbel-Softmax trick.
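A small numerical check of the lemma: normalizing the unnormalized conditional log-weights with a softmax yields a valid categorical distribution, and with equal weights it must recover the central hypergeometric PMF. A sketch under that assumption (function names are ours):

```python
import math

def cond_log_weights(m_l, m_r, w_l, w_r, n):
    """Unnormalized log-weights alpha_k of the conditional categorical
    distribution over x = k in {0, ..., m_l} (two-group form of Eq. 5)."""
    alphas = []
    for k in range(m_l + 1):
        if k > n or n - k > m_r:
            alphas.append(float("-inf"))  # infeasible count: probability zero
            continue
        a = k * math.log(w_l) + (n - k) * math.log(w_r)
        a -= math.lgamma(k + 1) + math.lgamma(n - k + 1)
        a -= math.lgamma(m_l - k + 1) + math.lgamma(m_r - n + k + 1)
        alphas.append(a)
    return alphas

def softmax(vals):
    mx = max(vals)
    e = [math.exp(v - mx) for v in vals]
    s = sum(e)
    return [v / s for v in e]

# with equal weights the normalized categorical must equal the central
# hypergeometric PMF: C(m_l, k) C(m_r, n-k) / C(m_l + m_r, n)
m_l, m_r, n = 4, 10, 6
p = softmax(cond_log_weights(m_l, m_r, 1.0, 1.0, n))
ref = [math.comb(m_l, k) * math.comb(m_r, n - k) / math.comb(m_l + m_r, n)
       for k in range(m_l + 1)]
```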

B.4 THOUGHT EXPERIMENT FOR MODELING THE HYPERGEOMETRIC DISTRIBUTION WITH A SEQUENCE OF UNCONSTRAINED CATEGORICAL DISTRIBUTIONS

We briefly touched on the topic of using a sequence of categorical distributions in the main text (Section 4.3). To further describe and discuss the problems of this approach, we provide a more detailed explanation and example here. Of course, the same constraints as described in the main paper apply, e.g. we want our method to be differentiable, scalable in the number of random states, and computationally efficient. Hence, we are interested in methods that infer at least all states of a single class in parallel, and not in methods that sequentially iterate over all possible random states for every re-sampling operation. In our example, we want to model a hypergeometric distribution with 3 classes and the following specifications:

m_1 = 10, m_2 = 7, m_3 = 8, n = 9   (30)

We use one categorical distribution for every class. The categories of each distribution describe the number of elements to sample from this class. Based on the three categorical distributions, we would like to sample x_j ∈ N_0 such that Σ_j x_j = n. The sequence of distributions ideally should describe the class conditional distributions of the hypergeometric distribution described in Section 4.1, or at least model sampling without replacement, i.e. fulfill the necessary constraint. We then have three vectors π_1, π_2, and π_3 which define the categorical weights of the three categorical distributions, such that Σ_k π_{j,k} = 1 ∀ j. Using three categorical distributions, we are not able to explicitly model ω, but if we were able to fulfill the constraints, there would be some matching ω. To make everything differentiable, we approximate the categorical distributions using the Gumbel-Softmax (GS) trick, such that we can use everything in a differentiable pipeline. In most differentiable settings, the categorical weights are inferred using some neural network, e.g.
π_j = f_{ξ,j}(·)   (31)

Knowing that our categories correspond to integer values, we can use the straight-through estimator (Bengio et al., 2013) together with the Gumbel-Softmax trick and a bit of matrix multiplication to convert a one-hot vector to an integer value:

y_j = straight_through(GS(π_j))   (32)
    = argmax(GS(π_j))   (33)

Please check the provided code for more details on the matrix multiplications. We can either infer the categorical weights of, and sample from, the three distributions in parallel or sequentially. We first try to infer the weights of all distributions in parallel. With the constraint Σ_j x_j = n, not all combinations of π_j are valid anymore. We might be able to learn the constraints such that the number of samples fulfilling them is maximized, but it is not possible to enforce the constraints with this formulation. Hence, it is not guaranteed that we sample x_j ∀ j such that Σ_j x_j = n, which then results in non-valid samples. Therefore, we try to infer the π_j sequentially. The sequential procedure models the following behaviour (similar to our proposed class conditional sampling):

p(x_1, x_2, x_3) = p(x_1) p(x_2 | x_1) p(x_3 | x_1, x_2)   (34)

Without loss of generality, we assume that we first infer π_1 and sample x_1 ∼ Cat(π_1) such that x_1 = 7. It follows that not all combinations of weights π_2 and π_3 are valid anymore. When inferring π_2, some weights have to be zero to guarantee that Σ_j x_j = n and that we sample a valid random vector. Hence,

π_{2,k} = 0 ∀ k > 2   (35)

If we assign any nonzero probability to π_{2,k}, k > 2, we are not able to fulfill the constraint Σ_j x_j = n for every generated sample. Additionally, from Equation (12) it follows that we would need to constrain the Gumbel noise as well. We are unaware of previous work that proposes a constrained Gumbel-Softmax trick which would model this behaviour.
Also, there is no guarantee that π_{2,k} > 0 for k ≤ 2, leading to additional heuristics-based solutions that we would need to implement. The effect on the calculation of the gradients is unclear and would need to be investigated. Also, for sampling and inferring the weights π_j, questions arise regarding the ordering between classes and how valid random samples can be guaranteed. We summarize that implementing and modeling a hypergeometric distribution using a sequence of unconstrained categorical distributions is far from trivial because of the open question of how to implement constraints in a general and dynamic way when using the Gumbel-Softmax trick. Note the important difference to our hypergeometric distribution: although we also use the Gumbel-Softmax trick to generate random samples, we are able to infer ω in parallel, which does not introduce an implicit ordering between classes, and the constraint Σ_j x_j = n is guaranteed. The only constraint with respect to ω we have to satisfy is ω_j ≥ 0 ∀ j, which can easily be satisfied using a ReLU activation function.

Algorithm 2 Subroutines for sampling from the multivariate noncentral hypergeometric distribution.

function SAMPLEUNCHG(m_i, m_j, ω_i, ω_j, n, τ)
    α_i ← calcLogPMF(m_i, m_j, ω_i, ω_j, n)    # Section 4.2
    x_i, r̃_i ← contRelaxSample(α_i, τ)    # Section 4.3
    return x_i, α_i, r̃_i
end function

function CALCLOGPMF(m_l, m_r, ω_l, ω_r, n)
    for k ∈ {0, ..., m_l} do
        x_{l,k} ← (k + 1)
        x_{r,k} ← (ReLU(n − k) + 1)
    end for
    ℓ_l ← log Γ(x_l + 1) + log Γ(m_l − x_l + 1)    # see Appendix B.5
    ℓ_r ← log Γ(x_r + 1) + log Γ(m_r − x_r + 1)
    α_l ← x_l log ω_l + x_r log ω_r − (ℓ_l + ℓ_r)
    return α_l
end function

function CONTRELAXSAMPLE(α_l, τ)
    u ← U(0, 1)
    g ← −log(−log u)
    r̃_l ← α_l + g
    p_l ← Softmax(r̃_l / τ)
    x_l ← Count-Index(Straight-Through(p_l))    # Count-Index and Straight-Through: see Appendix B.5
    return x_l, r̃_l
end function
We are then able to calculate the class conditional distributions as a function of ω such that all constraints are satisfied. We do not need to dynamically change categorical weights, i.e. model parameters, during the sampling procedure, as would be the case for a sequence of constrained categorical distributions.
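The sequential conditional sampling of Algorithms 1 and 2 can be sketched in stdlib Python; this is a simplified, non-differentiable version that uses the hard Gumbel-Max step where the paper uses the Gumbel-Softmax relaxation with straight-through gradients (all function names are ours):

```python
import math
import random

def gumbel():
    u = min(max(random.random(), 1e-12), 1.0 - 1e-12)
    return -math.log(-math.log(u))

def cond_log_pmf(m_l, m_r, w_l, w_r, n):
    """Unnormalized conditional log-PMF over x in {0, ..., m_l};
    counts violating the constraints get probability zero (-inf)."""
    alphas = []
    for k in range(m_l + 1):
        if k > n or n - k > m_r:
            alphas.append(float("-inf"))
            continue
        a = k * math.log(w_l) + (n - k) * math.log(w_r)
        a -= math.lgamma(k + 1) + math.lgamma(n - k + 1)
        a -= math.lgamma(m_l - k + 1) + math.lgamma(m_r - n + k + 1)
        alphas.append(a)
    return alphas

def sample_unchg(ms, ws, n):
    """Sequentially sample class counts x with sum(x) = n."""
    xs, remaining = [], n
    for i in range(len(ms)):
        rest_m = sum(ms[i + 1:])
        if rest_m == 0:
            xs.append(remaining)  # the last class is fully determined
            break
        # merge all remaining classes into a single "right" pseudo-class
        w_r = sum(w * m for w, m in zip(ws[i + 1:], ms[i + 1:])) / rest_m
        alphas = cond_log_pmf(ms[i], rest_m, ws[i], w_r, remaining)
        x_i = max(range(len(alphas)), key=lambda k: alphas[k] + gumbel())
        xs.append(x_i)
        remaining -= x_i
    return xs

random.seed(1)
draws = [sample_unchg([10, 7, 8], [1.0, 2.0, 0.5], 9) for _ in range(100)]
```

By construction, every draw satisfies Σ_j x_j = n and x_j ≤ m_j, which a sequence of unconstrained categorical distributions cannot guarantee.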

B.5 ALGORITHM AND MINIMAL EXAMPLE

In the main text, we drafted the pseudocode for our proposed algorithm (see Algorithm 1). We only presented the main functionality, not the subroutines described in Sections 4.2 and 4.3. Algorithm 2 describes the subroutines used in Algorithm 1 and explained in Section 4. The straight-through operator uses a hard one-hot embedding in the forward path, but the relaxed vector for the backward path and the calculation of the derivative (Bengio et al., 2013). The Count-Index operation maps the one-hot vector to an index, which in our case equals the number of selected class elements.
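The forward passes of the two operators can be sketched as follows (forward only — the gradient substitution of the straight-through estimator requires an autograd framework; function names are ours):

```python
def straight_through(p):
    """Forward pass of the straight-through operator: the relaxed
    probability vector p is replaced by a hard one-hot vector; in an
    autograd framework, the backward pass would use p unchanged."""
    k = max(range(len(p)), key=lambda i: p[i])
    return [1 if i == k else 0 for i in range(len(p))]

def count_index(one_hot):
    """Count-Index: map the one-hot vector to the integer it encodes,
    i.e. the number of selected class elements."""
    return one_hot.index(1)

relaxed = [0.1, 0.7, 0.2]  # e.g. a Gumbel-Softmax sample over counts 0..2
hard = straight_through(relaxed)
count = count_index(hard)
```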

B.5.1 MINIMAL EXAMPLE

We describe a minimal example application using a step-by-step procedure to provide intuition and further illustrate the proposed method. Here, we learn the generative model of an urn using stochastic gradient descent, given samples from an urn model with a priori unknown weights ω. Note that given a dataset, we could also estimate the weights ω of the hypergeometric distribution by minimizing the negative log-probability of the data given the parameters, −log p_X(x; ω). In contrast, we use a generative approach to demonstrate how our method allows backpropagation when modeling the generative process of the samples by reparameterizing the hypergeometric distribution. Additionally, we illustrate the minimal example with two figures (Figures 5 and 6), which explain the sampling procedure visually.

Let us start with our minimal example. We are given a dataset of i.i.d. samples 𝒳_D ∈ N_+^{K×n_c} from a multivariate noncentral hypergeometric distribution with unknown ω_gt ∈ R_+^c. K is the number of samples in the dataset and n_c defines the number of classes. A sample from the dataset 𝒳_D is denoted as X_D. For every X_D ∈ 𝒳_D, it holds that Σ_c X_{D,c} = n. We assume that we know the total number of elements in the urn, i.e. m = [m_1, m_2, ..., m_c]. In our minimal example, we are interested in learning the unknown importance weights ω with a generative model using stochastic gradient descent (SGD). Hence, we assume a data-generating distribution p_X(x; ω) such that X ∼ p_X(x; ω).

Figure 5: Illustration of the setting of the multivariate hypergeometric distribution with m_1 = 3, m_2 = 5, m_3 = 4, N = 12, and n = 5. We have 3 classes of elements (green, orange, and blue) with different and unknown importance ω_c for every class c. In our urn, the total number of elements N is given by the sum of elements of all classes m_c. From this urn, we draw a group of n samples; in this example, n = 5. The group importance ω_c is often unknown and difficult to estimate. Our formulation helps to learn ω_c using gradient-based optimization when simulating how given samples are drawn from the urn.

Figure 6: Illustration of the proposed conditional sampling from the multivariate noncentral hypergeometric distribution. We use the same urn as in Figure 5 with m = [3, 5, 4] and n = 5. As described in Section 4, we sequentially sample random variates of the individual classes. Hence, we start by sampling class 1 (left column): we merge classes 2 and 3 (illustrated by the half blue and half orange balls), creating the necessary parameters m_L = m_1 = 3, m_R = m_2 + m_3 = 9, ω_L = ω_1, and ω_R = (ω_2 m_2 + ω_3 m_3)/m_R for p_{X_L}(·), as also described in Algorithms 1 and 2. Using the univariate distribution p_{X_L}(·), we sample the random variable X_1 ∼ p_{X_L}(n, m_L, m_R, ω_L, ω_R) with n = 5, which equals 1 in our example (symbolized by the single green ball). We continue with sampling class 2 (middle column): the merge operation simplifies to m_L = m_2 = 5, m_R = m_3 = 4, ω_L = ω_2, and ω_R = ω_3, and n becomes the original n minus X_1, i.e. n = 4. We draw X_2 = 3 in our example (again illustrated by the 3 orange balls below the urn). Because the number of drawn balls must sum to n, the last class X_3 is fully determined already (right column: X_3 ∼ p_{X_L}(n, m_L, 0, ω_L, 0) with n = 1 and m_L = m_3 = 4).
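The merge operation used in the first step of Figure 6 (pooled class size, size-weighted average importance) can be sketched as a small helper (the function name is ours):

```python
def merge_classes(ms, ws):
    """Merge several classes into one pseudo-class: pooled size m_R and
    size-weighted average importance omega_R (cf. Figure 6, left column)."""
    m_r = sum(ms)
    w_r = sum(w * m for w, m in zip(ws, ms)) / m_r
    return m_r, w_r

# merging classes 2 and 3 of the urn in Figure 6 (m = [3, 5, 4])
m_r, w_r = merge_classes([5, 4], [2.0, 0.5])
```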
The loss function L is given as

L = Σ_{X_D ∈ 𝒳_D} E_{X ∼ p_X(x; ω)} [ (X_D − X)² ]   (36)
  = Σ_{X_D ∈ 𝒳_D} E_{X ∼ p_X(x; ω)} [ ℓ(X_D, X) ]

where ℓ is the loss per sample. p_X(x; ω) is a noncentral multivariate hypergeometric distribution as defined in Definition 3.1, where the class importance ω is unknown. To minimize E[ℓ(X_D, X)], we want to optimize ω. Using SGD, we optimize the parameters ω in an iterative manner:

ω_{t+1} := ω_t − η · d/dω E_{X ∼ p_X(x; ω)} [ ℓ(X_D, X) ]

where η is the learning rate and t is the step in the optimization process. Unfortunately, we do not have a reparameterization estimator for d/dω E[ℓ(X_D, X)] because of the jump discontinuities of the argmax function in the categorical distributions. As described in Sections 4.1 and 4.3, we can rewrite p_X(x; ω) as a sequence of conditional distributions. Every conditional distribution is itself a categorical distribution, which prohibits the calculation of d/dω E[ℓ(X_D, X)]. In more detail, we rewrite the joint probability distribution p_X(x; ω) as

p_X(x; ω) = p_{X_1}(x_1; ω) Π_{c=2}^{n_c} p_{X_c}(x_c | x_1, ..., x_{c−1}; ω)

where every distribution p_{X_c}(·; ω) is a categorical distribution. We sample every X_c using Equation (5), i.e.

log p_{X_c}(x_{L_c}; ω) = x_{L_c} log ω_{L_c} + (n_c − x_{L_c}) log ω_{R_c} + ψ_F(x_{L_c}) + C

ω_{L_c}, ω_{R_c}, m_{L_c}, m_{R_c}, and n_c = n − Σ_{j<c} X_j are calculated according to Equation (3), sequentially for every class. The expected element-wise loss E_{X ∼ p_X(x;ω)}[ℓ(X_D, X)] changes to

E_{X ∼ p_X(x;ω)}[ℓ(X_D, X)] = E_{X ∼ p_X(x;ω)} [ Σ_{c=1}^{n_c} (X_{D,c} − X_c)² ]   (41)
= E_{X ∼ p_X(x;ω)} [ Σ_{c=1}^{n_c} ℓ(X_{D,c}, X_c) ]   (42)
= Σ_{c=1}^{n_c} E_{X ∼ p_X(x;ω)} [ ℓ(X_{D,c}, X_c) ]

Hence,

d/dω E[ℓ(X_D, X)] = Σ_{c=1}^{n_c} d/dω E_{X ∼ p_X(x;ω)} [ ℓ(X_{D,c}, X_c) ]

Unfortunately, for every d/dω E[ℓ(X_{D,c}, X_c)], we face the problem of not having a reparameterizable gradient estimator.
We cannot calculate the gradients of the loss directly, but the p_{X_c}(·) being categorical distributions allows us to use the Gumbel-Softmax gradient estimator (Jang et al., 2016; Maddison et al., 2014; Paulus et al., 2020). The Gumbel-Softmax trick is a differentiable approximation to the Gumbel-Max trick (Maddison et al., 2014), which provides a simple and efficient way to draw samples from a categorical distribution with weights π. The Gumbel-Softmax trick uses the softmax function as a differentiable approximation to the argmax function used in the Gumbel-Max trick. It follows (Jang et al., 2016)

y = softmax((log π + g)/τ)   (45)
  =: softmax_τ(log π + g)

where g_1, ..., g_k are i.i.d. samples drawn from Gumbel(0, 1), and τ is a temperature parameter. y is a continuous approximation of a one-hot vector, i.e. 0 ≤ y_i ≤ 1 such that Σ_i y_i = 1. Different to the standard Gumbel-Softmax trick, we infer the weights π from the probability mass function log p_X(·) (see Equations (7) and (8)). We write the following for a single conditional class x_{L_c}, as the procedure is the same for all classes:

X_{τ,c}(ω, g) = softmax_τ( log p_{X_c}(x_{L_c}; ω) + g )

where x_{L_c} = [0, ..., m_{L_c}]. Due to the translation invariance of the softmax function, we do not need to calculate the constant C in log p_{X_c}(x_{L_c}; ω). The Gumbel-Softmax approximation is smooth for τ > 0, and therefore E[ℓ(X_D, X_τ)] has well-defined gradients with respect to ω. We write the loss function to optimize and its gradients as

E_g[ℓ(X_D, X_τ(ω, g))] = Σ_{c=1}^{n_c} E_g[ℓ(X_{D,c}, X_{τ,c}(ω, g))]   (48)
d/dω E_g[ℓ(X_D, X_τ(ω, g))] = Σ_{c=1}^{n_c} E_g[ d/dω ℓ(X_{D,c}, X_{τ,c}(ω, g)) ]   (49)

By replacing the categorical distribution in Equation (5) with the Gumbel-Softmax distribution (see Lemma 4.2), we can thus use backpropagation and automatic differentiation frameworks to compute gradients and optimize the parameters ω (Jang et al., 2016). We implemented our minimal example for n_c = 3 classes.
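A stdlib sketch of one relaxed conditional draw as described above: the Gumbel-Softmax vector over the unnormalized log-PMF is turned into a soft count by taking the expected index under the relaxed one-hot vector (the α values below are the central two-group case with m_L = 4, m_R = 10, n = 6; helper names are ours):

```python
import math
import random

def gumbel():
    u = min(max(random.random(), 1e-12), 1.0 - 1e-12)
    return -math.log(-math.log(u))

def relaxed_count(alphas, tau):
    """Gumbel-Softmax relaxation of one conditional draw: softmax over the
    unnormalized log-PMF plus Gumbel noise, turned into a soft count."""
    y = [(a + gumbel()) / tau for a in alphas]
    mx = max(y)
    e = [math.exp(v - mx) for v in y]
    s = sum(e)
    p = [v / s for v in e]
    # soft count: expected index under the relaxed one-hot vector
    soft_x = sum(k * pk for k, pk in enumerate(p))
    return soft_x, p

# unnormalized log-weights for k = 0..4 (equal importance weights)
alphas = [-(math.lgamma(k + 1) + math.lgamma(6 - k + 1)
            + math.lgamma(4 - k + 1) + math.lgamma(10 - 6 + k + 1))
          for k in range(5)]
random.seed(2)
soft_x, p = relaxed_count(alphas, tau=0.5)
```

Because the whole mapping from ω to the soft count is smooth for τ > 0, an autograd framework could differentiate through it; the constant C cancels in the softmax.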
We set m = [m_1, m_2, m_3] = [200, 200, 200] and n = 180. We create 10 datasets 𝒳_D ∈ N^{1000×3} generated from different ω_gt and show the performance of the proposed method. From these 1000 samples, we use 800 for training and 200 for validation. Similar to the setting we use for the KS-test (see Section 5.1), we choose 10 values for ω_{gt,2}, i.e. ω_{gt,2} ∈ {1.0, 2.0, ..., 10.0}. The values for ω_{gt,1} and ω_{gt,3} are set to 1.0 for all dataset versions. As described above, the model does not have access to the data-generating true ω_gt. So, for every dataset 𝒳_D, we optimize the unknown ω based on the loss L defined in Equation (43). Figure 7 shows the training and validation losses over the training steps. We train the model for 10 epochs, but we see that the model converges earlier. The losses only differ at the beginning of the training procedure, which is probably an initialization effect, and quickly converge to similar values independent of the ω_2 value that generated the dataset. Figure 8 shows the estimation of log ω; the x-axis shows the training step, and the y-axis shows the estimated value. Figures 8a and 8b demonstrate that the hypergeometric distribution is invariant to the scale of ω: with increasing value of ω_{gt,2}, the values of ω_1 and ω_3 decrease, although their ground-truth values ω_{gt,1} and ω_{gt,3} do not change. The training and validation losses do not increase, though (Figures 7a and 7b), which demonstrates the scale-invariance.

C EXPERIMENTS

All our experiments were performed on our internal compute cluster, equipped with NVIDIA RTX GPUs. To highlight our accurate approximation, we provide more results from KS-tests here in the appendix. We perform additional tests with varying n and m_2; see Figures 9 and 10 for the detailed results. We see that over all combinations, our approximation of the hypergeometric distribution performs well and produces samples of approximately the same quality as the reference distribution.

C.2 WEAKLY-SUPERVISED LEARNING

C.2.1 METHOD, IMPLEMENTATION AND HYPERPARAMETERS

In this section, we give more details on the methods used. We make use of disentanglement_lib (Locatello et al., 2019), which is also used in the original paper we compare to (Locatello et al., 2020). The baseline algorithms (Bouchacourt et al., 2018; Hosoya, 2018) are already implemented in disentanglement_lib. For details on the implementation of the models, we refer to the original paper. We did not change any hyperparameters or network settings. All experiments were performed using β = 1.0, as this is the best-performing β according to Locatello et al. (2020). For all experiments, we train three models with different random seeds. All experiments are performed using GroupVAE (Hosoya, 2018). In GroupVAE, shared latent factors are aggregated using an arithmetic mean. Bouchacourt et al. (2018) also assume knowledge about shared and independent latent factors. In contrast to GroupVAE, their ML-VAE aggregates shared latent factors using a Product of Experts (i.e. a geometric mean). Figure 12 shows the basic architecture. The different architectures only differ in the View Aggregation module. In this module, every method selects the latent factors z_i ∈ S, which should be aggregated over the different views x_1 and x_2.

Figure 11: As an additional evaluation, we analyze how disentangled the latent representations are. Related to that, we assess the quality of the learned latent representation using a linear classifier. We see that the dynamics over different k seem to be related for the disentanglement and downstream performance of the learned latent representations.

Figure 12: Setup for the weakly-supervised experiment. The three methods differ only in the View Aggregation module.

Given a subset S of shared latent factors, it
follows

q̃_ϕ(z_i | x_j) = avg( q_ϕ(z_i | x_1), q_ϕ(z_i | x_2) )   ∀ i ∈ S   (50)
q̃_ϕ(z_i | x_j) = q_ϕ(z_i | x_j)   otherwise

where avg is the averaging function of choice as described above and j ∈ {1, 2}. The methods used (i.e. LabelVAE, AdaVAE, HGVAE) differ in how they select the subset S.
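The two aggregation choices mentioned above can be sketched for a single univariate Gaussian latent dimension; this is a simplified illustration (the actual implementations in disentanglement_lib may parameterize things differently):

```python
def avg_arithmetic(mu1, var1, mu2, var2):
    """GroupVAE-style aggregation: arithmetic mean of the per-view
    Gaussian parameters (simplified sketch)."""
    return (mu1 + mu2) / 2.0, (var1 + var2) / 2.0

def avg_poe(mu1, var1, mu2, var2):
    """ML-VAE-style Product of Experts: the product of two Gaussian
    densities is a Gaussian with precision-weighted mean and pooled precision."""
    prec = 1.0 / var1 + 1.0 / var2
    mu = (mu1 / var1 + mu2 / var2) / prec
    return mu, 1.0 / prec

mu_a, var_a = avg_arithmetic(0.0, 2.0, 1.0, 2.0)
mu_p, var_p = avg_poe(0.0, 2.0, 1.0, 2.0)
```

Note the qualitative difference: the Product of Experts sharpens the aggregated posterior (its variance is at most the smaller of the two input variances), while the arithmetic mean does not.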

C.2.2 HYPERGEOMETRICVAE (IN MORE DETAIL)

In our approach (HGVAE), we model the number of shared and independent latent factors of a pair of images as discrete random variables following a hypergeometric distribution with unknown ω. In reference to the urn model, shared and independent factors each correspond to one color, and the urn contains d marbles of each color, where d is the dimensionality of the latent space. Given the correct weights ω and drawing from the urn d times, the number of marbles of each respective color corresponds, in expectation, to the correct number of independent/shared factors. The proposed formulation allows us to simultaneously infer such ω and learn the latent representation in a fully differentiable setting within the weakly-supervised pipeline of Locatello et al. (2019). To integrate the procedure described above in this framework, we need two additional building blocks. First, we introduce a function that returns log ω. To achieve this, we use a single dense layer which returns the logits log ω. The input to this layer is a vector γ containing the symmetric version of the KL divergences between pairs of latent distributions, i.e. for latents P and Q, the vector contains ½(KL(P||Q) + KL(Q||P)). Second, sampling from the hypergeometric distribution with these weights leads to estimates k̂ and ŝ. Consequently, we need a method to select k̂ factors out of the d that are given. Similar to the original paper, we select the factors achieving the highest symmetric KL divergence. To do this, we sort γ in descending order using the stochastic sorting procedure neuralsort (Grover et al., 2019). This enables us to select the top k̂ independent as well as the bottom ŝ = d − k̂ shared latent factors. Like AdaVAE, we substitute the shared factors with the mean value of the original latent code before continuing the VAE forward pass in the usual fashion.
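For univariate Gaussian latents, the symmetrized KL score γ_i has a simple closed form; a sketch (the model applies this per latent dimension, and the function names are ours):

```python
import math

def kl_gauss(mu1, var1, mu2, var2):
    """KL(N(mu1, var1) || N(mu2, var2)) for univariate Gaussians."""
    return 0.5 * (math.log(var2 / var1) + (var1 + (mu1 - mu2) ** 2) / var2 - 1.0)

def sym_kl(mu1, var1, mu2, var2):
    """Symmetrized divergence 0.5 * (KL(P||Q) + KL(Q||P)),
    used as the per-dimension score gamma_i."""
    return 0.5 * (kl_gauss(mu1, var1, mu2, var2) + kl_gauss(mu2, var2, mu1, var1))

g_same = sym_kl(0.0, 1.0, 0.0, 1.0)  # identical posteriors -> score 0
g_far = sym_kl(0.0, 1.0, 3.0, 1.0)   # distant posteriors -> large score
```

Dimensions whose two views agree get a small γ_i (candidates for the shared set), while disagreeing dimensions get a large γ_i (candidates for the independent set).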

C.2.3 HYPERPARAMETER SENSITIVITY

We perform ablations to examine how sensitive HypergeometricVAE is to certain hyperparameters. We find that the temperature and the learning rate have the most influence on training stability and convergence. First, if we set the temperature τ too low, we often observe vanishing or exploding gradients. This well-known artifact of the Gumbel-Softmax trick can be avoided by introducing temperature annealing. Hence, similar to the original Gumbel-Softmax trick (Jang et al., 2016; Maddison et al., 2017) and the neuralsort implementation (Grover et al., 2019), we anneal the temperature τ using an exponential function

τ_t = τ_init exp(−r t)

where t is the current training step, τ_init is the initial temperature, and r is the annealing rate:

r = (log τ_init − log τ_final) / n_steps

τ_final is the final temperature value and n_steps is the number of annealing steps. As shown in Figure 13, training loss and shared factor estimation then converge almost independently of the final temperature. In our final experiments, we use identical temperatures τ for both the differentiable hypergeometric distribution and neuralsort. We set the initial temperature τ_init to 10 and the final temperature τ_final to 0.01, annealed over n_steps = 50000 steps. Further, we find the learning rate to be the most crucial hyperparameter in terms of convergence. Higher learning rates generally seem to lead to worse training losses. On the other hand, estimating the number of shared factors seems robust on average, although high standard deviations imply decreasing consistency for higher learning rates. We demonstrate this finding in Figure 13. We used an initial learning rate of 10^−6 together with the Adam optimizer (Kingma & Ba, 2014) for our final experiments.

Figure 13: The final temperature had minimal impact on stability and convergence, whereas higher learning rates led to some instabilities.
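The annealing schedule above as a small sketch (default values taken from the text):

```python
import math

def annealed_tau(t, tau_init=10.0, tau_final=0.01, n_steps=50000):
    """Exponential temperature annealing: tau_t = tau_init * exp(-r * t)
    with annealing rate r = (log tau_init - log tau_final) / n_steps."""
    r = (math.log(tau_init) - math.log(tau_final)) / n_steps
    return tau_init * math.exp(-r * t)
```

By construction, the schedule starts at τ_init at step 0 and reaches τ_final exactly at step n_steps, decaying monotonically in between.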
Finally, we also experimented with weighting the KL divergence with a β term as in the β-VAE (Higgins et al., 2017), but did not find an influence on stability and convergence. Hence, we left it at the default value of 1 in our experiments.

C.2.4 DATA

The mpi3d dataset (Gondal et al., 2019) consists of frames displaying a robot arm and is based on 7 generative factors:

• object color, shape, and size
• camera height
• background color
• horizontal and vertical axis

For more details on the dataset and in general, we refer to https://github.com/rr-learning/disentanglement_dataset.

Table 3: We report the runtimes for the three methods used in the weakly-supervised experiment, i.e. labelVAE, adaptiveVAE, and hypergeometricVAE, as mean and standard deviation over all runs per experiment, for the different numbers of independent factors k = {-1, 2, 4, 6}.

C.2.5 DOWNSTREAM TASK ON THE LEARNED LATENT REPRESENTATIONS

For the downstream task, we randomly sample 10000 samples from the training set and 5000 samples from the test set. For each sample, we extract the predicted shared and the predicted independent parts of both views. Then, for every generative factor of the dataset, three individual classifiers are trained on the respective latent representations of the 10000 training samples. Afterwards, every classifier evaluates its predictive performance on the latent representations of the 5000 test samples. To arrive at the final scores, we extract the predictions of the shared factors on the shared representation and compute the balanced accuracy. Similarly, we calculate the balanced accuracy of the independent factors on the respective independent-representation classifiers and average their balanced accuracies. Because the number of classes differs between generative factors, we report the adjusted balanced accuracy, using the implementation from scikit-learn (Pedregosa et al., 2011); for details, see https://scikit-learn.org/stable/modules/generated/sklearn.metrics.balanced_accuracy_score.html. For all shared generative factors, we average the accuracies of the individual classifiers into a single average balanced accuracy, and we do the same for the independent factors. This allows us to report the amount of shared and independent information present in the learned latent representation; we report these averages in the main text. To evaluate the latent representations, we train linear classifiers; more specifically, we use logistic regression classifiers (Cox, 1958) from scikit-learn (Pedregosa et al., 2011). To train them, we increased the max_iter parameter so that all models converged and left all other settings at their defaults.
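To make the adjusted balanced accuracy concrete, the following is a minimal pure-Python sketch mirroring what sklearn.metrics.balanced_accuracy_score with adjusted=True computes: the mean per-class recall, rescaled so that random guessing scores 0 and perfect prediction scores 1. The function name and the toy labels are our own:

```python
def adjusted_balanced_accuracy(y_true, y_pred):
    """Mean per-class recall, chance-corrected so that a random classifier
    scores 0 regardless of the number of classes."""
    classes = sorted(set(y_true))
    recalls = []
    for c in classes:
        idx = [i for i, y in enumerate(y_true) if y == c]
        recalls.append(sum(y_pred[i] == c for i in idx) / len(idx))
    bal_acc = sum(recalls) / len(classes)
    chance = 1.0 / len(classes)  # expected balanced accuracy of random guessing
    return (bal_acc - chance) / (1.0 - chance)

# Three classes; one class-2 sample is misclassified as class 0:
score = adjusted_balanced_accuracy([0, 0, 1, 1, 2, 2], [0, 0, 1, 1, 2, 0])
# balanced accuracy = (1 + 1 + 0.5) / 3 = 5/6; adjusted = (5/6 - 1/3) / (2/3) = 0.75
```

The chance correction is what makes scores comparable across generative factors with different numbers of classes.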

C.2.6 RUNTIMES OF DIFFERENT ALGORITHMS

In general, our sampling method scales with the number of classes O(c), and the sampling for a single class scales with O(m_i), which is calculated in parallel in a single forward pass. To get a better and empirically validated picture of the overhead created by using the hypergeometric distribution, we also report the training times for the three methods compared. All methods were trained for an equal number of epochs and on identical hardware equipped with an NVIDIA GeForce GTX 1080. Table 3 reports the runtimes for all methods, averaged over 5 runs. We see that the overhead of using the hypergeometric distribution is almost negligible. The proposed hypergeometricVAE reaches approximately equal training runtimes as the labelVAE, which assumes that the number of shared factors is known. It even outperforms the adaptiveVAE, which uses a very simple heuristic but needs to resort to more sophisticated methods in order to avoid cutting gradients in its thresholding function.
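For intuition about what is being sampled, a naive non-differentiable reference sampler for weighted sampling without replacement can be sketched as a sequential urn scheme, where each draw picks class i with probability proportional to ω_i times the remaining items of that class. This Wallenius-type baseline costs O(n · c) per sample and is only an illustration, not the paper's parallel differentiable formulation; all names are ours:

```python
import random

def urn_sample(m, omega, n, rng=random):
    """Draw n items without replacement from c classes with m[i] items each.
    At every step, class i is drawn with probability proportional to
    omega[i] * remaining[i] (sequential urn scheme; non-differentiable)."""
    remaining = list(m)
    counts = [0] * len(m)
    for _ in range(n):
        weights = [w * r for w, r in zip(omega, remaining)]
        total = sum(weights)
        u = rng.random() * total
        acc = 0.0
        for i, w in enumerate(weights):
            acc += w
            if u <= acc:
                counts[i] += 1
                remaining[i] -= 1
                break
    return counts

# Same configuration as the sensitivity analyses: m = {200, 200, 200},
# omega = {1.0, 5.0, 1.0}, n = 180.
counts = urn_sample([200, 200, 200], [1.0, 5.0, 1.0], 180)
```

The hard constraints are visible here: each draw depends on all previous draws, and the counts always satisfy 0 ≤ counts[i] ≤ m[i] with sum(counts) = n, which is exactly what makes a naive relaxation non-trivial.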

C.3.2 IMPLEMENTATION DETAILS

To implement our model we adopted a feed-forward architecture for both the encoder and decoder of the VAE with four layers of 500, 500, 2000, D units respectively, where D = 10. The VAE is pretrained using the same layer-wise pretraining procedure used by Jiang et al. (2016) . Each data set is divided into training and test sets, and all the reported results are computed on the latter. We employed the same hyper-parameters for all experiments. In particular, the learning rate is set to 0.001, the batch size is set to 128 and the models are trained for 1000 epochs. Additionally, we used an annealing schedule for the temperature of the Gumbel-Softmax trick. As the VaDE is rather sensitive to initialization, we used the same pretraining weights provided by Jiang et al. (2016) . These weights have been selected by the baseline to enhance the performance of their method, leading to an optimistic outcome. If a random initialization were used instead, the clustering performance would be lower. Nonetheless, the focus of our work is on comparing methods rather than on absolute performance values. Figure 14 displays the general architecture of all methods. The methods only differ in their definition of the prior probability distribution p(c; π).
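The training configuration described above can be summarized as a small config sketch; the values are taken from the text, while the dictionary and key names are our own:

```python
# Hyperparameters of the clustering experiments as stated in the text.
# Key names are illustrative; values come from the description above.
CLUSTERING_CONFIG = {
    # Layer widths used for both encoder and decoder; the last entry is D.
    "layers": [500, 500, 2000, 10],
    "latent_dim": 10,          # D
    "learning_rate": 1e-3,
    "batch_size": 128,
    "epochs": 1000,
}

assert CLUSTERING_CONFIG["layers"][-1] == CLUSTERING_CONFIG["latent_dim"]
```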



Footnotes:
• Huijben et al. (2021) provide a great review article of the Gumbel-Max trick and its extensions, describing recent algorithmic developments and applications.
• E.g. TensorFlow (Abadi et al., 2016) or PyTorch (Paszke et al., 2019).
• The code can be found here: https://github.com/thomassutter/mvhg
• Although the distribution itself is older, Gonin (1936) were the first to name it the hypergeometric distribution.
• See Appendix B.1 for the definition of the Gamma function.



Figure 1: Comparing random variables from the proposed differentiable formulation to a non-differentiable reference implementation. We draw samples from a multivariate noncentral hypergeometric distribution consisting of three classes, with m_i = 200 ∀i and n = 180. The class weights ω_1 and ω_3 for classes 1 and 3 are set to 1.0; ω_2 is increased from 1.0 to 10.0 with a step size of 1.0. Figure 1a shows histograms of the number of elements for class 2. Figure 1b shows the calculated distance values of the KS test between the reference and proposed implementation (upper plot) and their respective p-values (lower plot).



Figure 2: We report the mean squared error MSE(s, ŝ) between the true number of shared factors s and the estimate ŝ to assess the models' performance.

Figure 3: True class (■) vs. learned hypergeometric cluster weights π i (■)

Figure 4: Comparing random variables drawn from the proposed distribution to a reference distribution.

Figure 7: Training and validation losses for different values of ω gt of our minimal example described in Appendix B.5.1.

(a) Estimation log ω1 (b) Estimation of log ω2 (c) Estimation of log ω3

Figure 8: The estimated log ω values over the training procedure for different ground-truth ω_gt values of our minimal example (see Appendix B.5.1). We see that the hypergeometric distribution is invariant to the scale of ω: with increasing value of ω_2, the estimated values for ω_1 and ω_3 change as well, but the training and validation losses remain low (see Figure 7).

(a) Histograms class 1 (b) Histograms class 2 (c) Histograms class 3 (d) Distance and p-values

Figure 9: Sensitivity analysis for a varying number of drawn samples n, measured with the Kolmogorov-Smirnov test. In this experiment, we have the following specifications: m = {200, 200, 200}, ω = {1.0, 5.0, 1.0}. The values of n are given in the plot.

Figure 10: Sensitivity analysis for a varying total number of elements m measured with the Kolmogorov-Smirnov test. In this experiment, we have the following specifications: n = 200, ω = {1.0, 5.0, 1.0}, m 1 = 200 and m 3 = 200. The values of m 2 are given in the plot.

(a) Average train loss when varying τ f inal . (b) Average train loss when varying the learning rate. (c) Mean squared error of estimating the number of shared factors when varying τ f inal . (d) Mean squared error of estimating the number of shared factors when varying the learning rate.

Figure 13: Ablation of varying the final temperature τ f inal (left) and the learning rate (right) of the HGVAE in the weakly supervised experiment. We made ablations for the training loss (top) and mean squared error of estimating the number of shared factors (bottom). The final temperature had minimal impact on stability and convergence, whereas higher learning rates led to some instabilities.

Table 3 (runtimes, mean ± standard deviation over 5 runs; one column per value of k):

LABEL            … ± 273.0         20921.1 ± 705.9   21389.3 ± 163.8   21761.1 ± 567.9
ADAPTIVE         29071.5 ± 133.2   28609.8 ± 439.9   29479.1 ± 487.1   29966.3 ± 303.3
HYPERGEOMETRIC   21888.9 ± 632.9   21299.4 ± 293.6   21863.9 ± 190.1   22241.0 ± 137.8

Figure 14: General Architecture for the clustering experiments. All methods have the same architecture details. They only differ in their definition of the prior distribution p(c; π).


Seiji Ono, Kazuharu Misawa, and Kazuki Tsuji. Effect of group selection on the evolution of altruistic behavior. Journal of Theoretical Biology, 220(1):55-66, 2003.

Mario Paolucci, Rosaria Conte, and Gennaro Di Tosto. A model of social organization and the evolution of food sharing in vampire bats. Adaptive Behavior, 14(3):223-238, 2006.

Adam Paszke, Sam Gross, Francisco Massa, Adam Lerer, James Bradbury, Gregory Chanan, Trevor Killeen, Zeming Lin, Natalia Gimelshein, Luca Antiga, Alban Desmaison, Andreas Köpf, Edward Z. Yang, Zach DeVito, Martin Raison, Alykhan Tejani, Sasank Chilamkurthy, Benoit Steiner, Lu Fang, Junjie Bai, and Soumith Chintala. PyTorch: An Imperative Style, High-Performance Deep Learning Library. CoRR, abs/1912.01703, 2019.

Max B. Paulus, Dami Choi, Daniel Tarlow, Andreas Krause, and Chris J. Maddison. Gradient estimation with stochastic softmax tricks. arXiv preprint arXiv:2006.08063, 2020.

Fabian Pedregosa, Gaël Varoquaux, Alexandre Gramfort, Vincent Michel, Bertrand Thirion, Olivier Grisel, Mathieu Blondel, Peter Prettenhofer, Ron Weiss, Vincent Dubourg, and others. Scikit-learn: Machine learning in Python. Journal of Machine Learning Research, 12:2825-2830, 2011.

Danilo Jimenez Rezende, Shakir Mohamed, and Daan Wierstra. Stochastic backpropagation and approximate inference in deep generative models. In International Conference on Machine Learning, pp. 1278-1286. PMLR, 2014.

All experiments ran on hardware equipped with NVIDIA RTX 2080 and NVIDIA GTX 1080 GPUs; every training and test run used only a single GPU. The weakly-supervised runs took approximately 3-4 hours each, whereas the clustering runs only took about 3 hours each. We report detailed runtimes for the weakly-supervised experiments in Table 3 to highlight the efficiency of the proposed method.

ACKNOWLEDGMENTS

We would like to thank Ričards Marcinkevičs for helpful discussions and help with the Kolmogorov-Smirnov test. TS is supported by the grant #2021-911 of the Strategic Focal Area "Personalized

C.3 CLUSTERING

In this section, we provide the interested reader with more details on the clustering experiments. In the following we describe the models, the optimisation procedure and the implementation details used in the clustering task.

C.3.1 MODEL

We follow a deep variational clustering approach as described by Jiang et al. (2016). Given a dataset X = {x_i}_{i=1}^N that we wish to cluster into K groups, we consider the generative assumptions of Equation (54), where c = {c_i}_{i=1}^N are the cluster assignments, z_i ∈ R^D are the latent embeddings of a VAE, and x_i is assumed to be binary for simplicity. In particular, assuming the generative process described in Equation (54), we can write the joint probability of the data, also known as the likelihood function. Different from Jiang et al. (2016), the prior probability p(c; π) cannot be factorized, as the p(c_i; π) for i = 1, ..., K are not independent. By using a variational distribution q_ϕ(Z, c|X), we obtain an evidence lower bound. For the sake of simplicity, we assume an amortized mean-field variational distribution, as in previous work (Jiang et al., 2016; Dilokthanakul et al., 2016). In the resulting ELBO formulation, all terms except the first one can be efficiently calculated as in previous work (Jiang et al., 2016). For the remaining term, we rely on a sampling scheme where we use the SGVB estimator and the Gumbel-Softmax trick (Jang et al., 2016) to sample from the variational distributions q_ϕ(z_i|x_i) and q_ϕ(c_i|x_i), respectively. The latter is set to a categorical distribution. L is the number of Monte Carlo samples and is set to 1 in all experiments.
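The Gumbel-Softmax sampling step for the categorical variational distribution q_ϕ(c_i|x_i) mentioned above can be sketched as follows; this is a minimal stdlib-Python illustration, with the function name and example logits our own:

```python
import math
import random

def gumbel_softmax_sample(logits, tau, rng=random):
    """Draw a relaxed one-hot sample from a categorical distribution using
    the Gumbel-Softmax trick (Jang et al., 2016): perturb each logit with
    Gumbel(0, 1) noise, then apply a temperature-scaled softmax."""
    gumbels = [-math.log(-math.log(max(rng.random(), 1e-12))) for _ in logits]
    scores = [(l + g) / tau for l, g in zip(logits, gumbels)]
    mx = max(scores)  # subtract the max for numerical stability
    exps = [math.exp(s - mx) for s in scores]
    z = sum(exps)
    return [e / z for e in exps]

# A relaxed cluster assignment over K = 3 clusters:
sample = gumbel_softmax_sample([2.0, 0.5, -1.0], tau=0.5)
```

The output lies on the probability simplex; as τ → 0 it approaches a one-hot vector, recovering a hard cluster assignment c_i, while for τ > 0 the sample remains differentiable with respect to the logits.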

