COMPLETE LIKELIHOOD OBJECTIVE FOR LATENT VARIABLE MODELS

Abstract

In this work, we propose Complete Latent Likelihood (CoLLike), an alternative to the Marginal Likelihood (MaL) objective for learning representations with latent variable models. We analyze both objectives from the perspective of matching joint distributions. We show that MaL corresponds to a particular KL divergence between a target joint distribution and the model joint. Furthermore, the properties of this target joint explain two major failures of MaL from the representation learning perspective: uninformative latents (posterior collapse) and high deviation of the aggregated posterior from the prior. In the CoLLike approach, we use a sample from the prior to construct a family of target joint distributions whose properties prevent these drawbacks. We use the complete likelihood both to choose the target from this family and to learn the model. We confirm our analysis with experiments on low-dimensional latents, which also indicate that high-accuracy unsupervised classification is achievable with the CoLLike objective.

1. INTRODUCTION

In the latent variable setting, the model defines a joint distribution over both observed variables x and latent variables z, while the training data contains only observed variables. The problem can be viewed as one of an unknown target conditional distribution z|x. There are at least two possible solutions: come up with a meaningful target z|x distribution and train the model as in a supervised setting, or give up and match only the marginals in the x domain. The latter is the choice of the MaL objective. In this work, we follow the former approach. However, instead of picking a single target conditional, we construct an entire family of possible distributions and use the model likelihood to decide which conditional to use as the target.

To construct the family of possible conditionals, we draw a sample from the prior of the same size as the dataset in the observed domain. All possible assignments of observed samples to latent ones span a family of empirical joint distributions, which can be represented as permutations of the latent samples. Although the set of permutations grows factorially with the dataset size, the permutation with the best likelihood can be found efficiently using combinatorial optimization. The resulting procedure resembles the expectation maximization algorithm (Dempster et al., 1977), where the expectation step is replaced with a combinatorial assignment problem. Furthermore, since the proposed algorithm uses gradient-free optimization to obtain the target distribution, the objective applies seamlessly to both continuous and discrete latent variables, whereas the discrete case is challenging for approaches based on MaL (Mnih & Gregor, 2014; Mnih & Rezende, 2016; Tucker et al., 2017).

We analyze the objectives from the perspective of matching joint distributions. We show that MaL corresponds to a specific choice of the target z|x conditional, while our approach considers a family of possible conditionals. The choice of target conditional is responsible for two major failures that arise during training with the MaL objective: the inability to learn informative latents, also known as "posterior collapse" (Bowman et al., 2016; Razavi et al., 2019; He et al., 2019), and divergence between the prior and the aggregated posterior (Hoffman & Johnson, 2016; Makhzani et al., 2015; Zhao et al., 2019; Kim & Mnih, 2018). These characteristics are vital for latent variable models because posterior collapse prevents learning a meaningful representation, and samples drawn from regions where the latent marginals deviate strongly suffer severe quality degradation (Rosca et al., 2018). The form of the target joint also explains the success of the complete likelihood on these challenges: the target distribution for CoLLike has high mutual information and matches the prior.

We verify our analysis with experiments. We focus on low-dimensional latent variables to allow a direct comparison with the exact MaL. Models trained with CoLLike stably maintain high mutual information and low divergence from the prior. In turn, MaL inevitably leads either to posterior collapse or to a highly divergent aggregated posterior. Previously, posterior collapse during the optimization of the exact likelihood has been shown for simple linear models (Lucas et al., 2019). Our experiments demonstrate that it can also happen with expressive models trained with exact likelihood. Along with informativeness and latent distribution matching, CoLLike shows no degradation of likelihood compared to MaL. Furthermore, we show that the CoLLike objective alone can achieve high accuracy in unsupervised classification.

We show that CoLLike unifies a range of existing approaches that lack probabilistic justification, among them Constrained K-means (Bennett et al., 2000), Permutation Invariant Training (Yu et al., 2017; Luo & Mesgarani, 2019), and Noise as Target (Bojanowski & Joulin, 2017). This allows us to extend them to different factorizations of the joint and to analyze them from the probabilistic perspective. Furthermore, CoLLike bridges the likelihood and optimal transport (OT) frameworks. From this perspective, the negative likelihood plays the role of both the mapping from the latent to the visible domain and the distance function.

2. COMPLETE LIKELIHOOD OBJECTIVE

In the regular latent variable setting, we are given a dataset $\{x_1, \dots, x_N\}$ and the model $p_\theta(x, z) = p_\theta(x|z)p(z)$. The missing $z$ can be treated as the missing $p_\delta(z|x)$ part of the target joint. If we cannot come up with a reasonable $z|x$ target, we can at least match the marginals in the observed domain with $\mathrm{KL}(p_\delta(x)\,\|\,p_\theta(x))$, in the hope that the model will learn an informative relation between $x$ and $z$. This is equivalent to the maximization of MaL:

$$L_{MaL}(\theta) = \sum_{i=1}^{N} \log p_\theta(x_i) = \sum_{i=1}^{N} \log \int p_\theta(x_i, z)\,dz \tag{1}$$

The justification of MaL comes from the equivalence between maximization of (1) and minimization of the Kullback-Leibler divergence $\mathrm{KL}(p_\delta(x)\,\|\,p_\theta(x))$, which measures the discrepancy between the target empirical data distribution $p_\delta(x)$ and the model distribution $p_\theta(x)$ (Murphy, 2022, 4.2.2). Note that this justifies MaL only for learning distributions of observed variables, not for learning representations. The fact that MaL does not promote informativeness (Alemi et al., 2018) shows the lack of justification of MaL for representation learning, because informativeness is a fundamental requirement for any useful representation.

Although the family of all possible target $p(z|x)$ distributions is enormous, we do not need to consider all of it. Firstly, the target distribution must be informative. Secondly, the marginal of the target joint distribution in the latent domain should match the prior $p(z)$. The fixed prior implies that the desired marginal distribution of $z$ is known. These requirements can be interpreted (Huszár, 2017) as the Infomax principle (Linsker, 1988).

It is not hard to obtain a rich family of distributions with such properties. We can obtain a collection $(z_1, \dots, z_N)$ by sampling from the prior and pair this collection with the dataset $(x_1, \dots, x_N)$. The pairing produces an empirical distribution. This empirical distribution attains the highest possible Mutual Information (MI) under the assumption that there are no repeated values of $x$ in the dataset (see Appendix B for the derivation), which satisfies the first requirement. Sampling from the prior addresses the second requirement, since the collection of $z$ samples converges to $p(z)$ (Cover & Thomas, 2006, Theorem 11.2.1). However, sampling effects can be a significant problem for high-dimensional latents.

We express each pairing as a permutation $\pi$, which produces a complete collection $((x_1, z_{\pi(1)}), \dots, (x_N, z_{\pi(N)}))$ and an empirical joint $p_{\delta\pi}(x, z) = p_\delta(x)p_\pi(z|x)$. Given a family of distributions, we need to decide which member of the family is our target. We propose to pick the one with the highest complete likelihood, relying on the inductive biases of the model. For this target, we then once again optimize the complete likelihood of the $(x_i, z_{\pi^*(i)})$ pairs with the optimal permutation $\pi^*$. These considerations lead us to the CoLLike objective:

$$L_{CL}(\theta, \pi) = \sum_{i=1}^{N} \log p_\theta(x_i, z_{\pi(i)}) \tag{2}$$

which we maximize with respect to both $\theta$ and $\pi$. An alternative view of the objective is the following: we sample $z$ values from the prior and assume that they are ground truth targets for the training dataset with unknown pairing. Figure 1 depicts the main difference between the objectives: CoLLike maximizes specific points of the joint distribution, while MaL maximizes whole lines along the joint. In the figure, bold lines and double circles mark the areas of the joint to be maximized.
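The two objectives can be compared on a toy problem. The sketch below is only an illustration under stated assumptions: the two-category uniform prior, the unit-variance Gaussian likelihood with means `MU`, the data `xs`, and the fixed list `zs` (a stand-in for a sample from the prior) are all hypothetical and not taken from the experiments. It evaluates (1) and (2) and finds the best pairing by brute force.

```python
import itertools
import math

# Hypothetical toy model: uniform prior over K = 2 categories and a
# unit-variance Gaussian p(x | z = k) centred at MU[k].
MU = [-1.0, 1.0]
K = len(MU)

def log_joint(x, k):
    """log p(x, z = k) = log p(z = k) + log N(x; MU[k], 1)."""
    log_prior = -math.log(K)
    log_lik = -0.5 * (x - MU[k]) ** 2 - 0.5 * math.log(2 * math.pi)
    return log_prior + log_lik

def marginal_ll(xs):
    """Objective (1): sum_i log sum_k p(x_i, z = k)."""
    return sum(math.log(sum(math.exp(log_joint(x, k)) for k in range(K)))
               for x in xs)

def complete_ll(xs, zs, perm):
    """Objective (2): sum_i log p(x_i, z_perm[i])."""
    return sum(log_joint(x, zs[perm[i]]) for i, x in enumerate(xs))

xs = [-1.2, 0.8, 1.1, -0.9]   # observed data (hypothetical)
zs = [0, 1, 0, 1]             # stand-in for a sample from the prior

# Brute-force maximisation over all N! pairings (only viable for tiny N).
best_perm = max(itertools.permutations(range(len(xs))),
                key=lambda perm: complete_ll(xs, zs, perm))
paired_z = [zs[j] for j in best_perm]
```

In this example the best pairing assigns category 0 to the negative observations and category 1 to the positive ones. For a discrete latent, each term of (2) is bounded by the corresponding term of (1), because a single summand $p_\theta(x_i, z)$ cannot exceed the sum over all $z$.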

3. OBJECTIVE ANALYSIS

We start our analysis by proving that MaL corresponds to the matching of a specific joint distribution and the model joint:

$$\begin{aligned}
\mathrm{KL}(p_\delta(x)p_\theta(z|x)\,\|\,p_\theta(x, z)) &= \mathbb{E}_{x,z\sim p_\delta(x)p_\theta(z|x)}\left[\log\frac{p_\delta(x)p_\theta(z|x)}{p_\theta(x)p_\theta(z|x)}\right] \\
&= \mathbb{E}_{x\sim p_\delta(x)}\left[\log\frac{p_\delta(x)}{p_\theta(x)}\right] \\
&= \mathbb{E}_{x\sim p_\delta(x)}[\log p_\delta(x)] - \mathbb{E}_{x\sim p_\delta(x)}[\log p_\theta(x)] \\
&= C - \frac{1}{N}\sum_{i}\log p_\theta(x_i) = C - \frac{1}{N}L_{MaL}(\theta)
\end{aligned}$$

where $C$ is a constant. The joint KL form of MaL brings new perspectives on the objective. It might be tempting to think of MaL as a workaround for unknown latents that allows us not to specify the target $z|x$ conditional. However, if we ask which distribution we want to mimic, the joint form reveals that the target conditional is actually specified and equals $p_\theta(z|x)$. This implies that we are aiming to keep the model posterior unchanged. In addition, the form highlights the intimate connection between MaL and the posterior.

CoLLike and the Evidence Lower Bound (ELBO), a common variational (Jordan et al., 1999) approximation of MaL, can also be expressed as KL divergences between joint distributions (see Table 1). We refer to Appendix A for the derivation of the equivalence. Note the similarity between the objectives, which becomes obvious in the joint KL form. All divergences share the model $p_\theta(x, z)$ as the second argument, which implies that the first argument is the target joint distribution. For all objectives, the target joint contains the data distribution $p_\delta(x)$ as its marginal in the $x$ domain, so the only difference is in the target $z|x$ conditional. Therefore, all the considered objectives belong to the family of the following form:

$$\begin{aligned}
L(\theta) &= \mathrm{KL}(p_\delta(x)p(z|x)\,\|\,p_\theta(x, z)) \\
&= \mathrm{KL}(p_\delta(x)p(z|x)\,\|\,p_\theta(x)p_\theta(z|x)) \\
&= \mathrm{KL}(p_\delta(x)\,\|\,p_\theta(x)) + \mathbb{E}_{p_\delta(x)}\left[\mathrm{KL}(p(z|x)\,\|\,p_\theta(z|x))\right]
\end{aligned} \tag{3}$$

Since the second term in (3) is non-negative, each divergence in the family upper-bounds $\mathrm{KL}(p_\delta(x)\,\|\,p_\theta(x))$; equivalently, all objectives in the family are lower bounds on the likelihood up to an additive constant. Note that the $z|x$ target conditional is used to minimize the overall divergence.
This affects the second term of (3) and makes the lower bound tighter.

Table 1: Considered objectives and their joint KL forms.

Objective    Original objective                                                                                             Joint KL form
CoLLike      $\sum_{i=1}^{N} \log p_\theta(x_i, z_{\pi(i)})$                                                                $\mathrm{KL}(p_\delta(x)p_\pi(z|x)\,\|\,p_\theta(x, z))$
MaL          $\sum_{i=1}^{N} \log p_\theta(x_i)$                                                                            $\mathrm{KL}(p_\delta(x)p_\theta(z|x)\,\|\,p_\theta(x, z))$
ELBO         $\sum_{i=1}^{N} \mathbb{E}_{z\sim q_\phi(z|x_i)}\left[\log\frac{p_\theta(x_i, z)}{q_\phi(z|x_i)}\right]$      $\mathrm{KL}(p_\delta(x)q_\phi(z|x)\,\|\,p_\theta(x, z))$

Despite the common traits, the objectives are different. We highlight a few differences here and go deeper in the following sections. Firstly, the target conditional for CoLLike, $p_\pi(z|x)$, is empirical, while its counterparts $p_\theta(z|x)$ and $q_\phi(z|x)$ are not. Secondly, in the MaL approach, we construct a particular joint distribution $p_\delta(x)p_\theta(z|x)$ and use it as the target joint, while in CoLLike we construct an entire family of joint distributions with the desired properties. Thirdly, the target posterior is readily available for CoLLike and ELBO, while for MaL it can be intractable. Furthermore, CoLLike can be seamlessly applied to discrete variables, while optimization of ELBO for discrete latents is challenging (Mnih & Gregor, 2014; Mnih & Rezende, 2016; Tucker et al., 2017). Lastly, the CoLLike objective allows learning models with the reverse factorization $p_\theta(x)p_\theta(z|x)$, while MaL and ELBO do not. Reverse factorization is another inductive bias that may or may not be useful. Furthermore, it can be significantly faster than the regular factorization if $p_\theta(x)$ is assumed to be uniform and $p_\theta(z|x)$ is factorized.
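The decomposition in (3) can be checked numerically on a small discrete example. The sketch below assumes binary $x$ and $z$; the probability tables `p_x`, `p_z_given_x`, and `model_joint` are arbitrary illustrative numbers, not values from the paper.

```python
import math

def kl(p, q):
    """KL(p || q) for two discrete distributions given as aligned lists."""
    return sum(pi * math.log(pi / qi) for pi, qi in zip(p, q) if pi > 0)

def flatten(matrix):
    return [v for row in matrix for v in row]

# Hypothetical 2x2 example: rows index x, columns index z.
p_x = [0.5, 0.5]                          # target marginal p_delta(x)
p_z_given_x = [[0.9, 0.1], [0.2, 0.8]]    # some target conditional p(z | x)
model_joint = [[0.3, 0.3], [0.1, 0.3]]    # model joint p_theta(x, z)

target_joint = [[p_x[i] * p_z_given_x[i][j] for j in range(2)]
                for i in range(2)]

# Model marginal p_theta(x) and posterior p_theta(z | x).
model_x = [sum(row) for row in model_joint]
model_z_given_x = [[v / model_x[i] for v in row]
                   for i, row in enumerate(model_joint)]

# Left-hand side of (3) and its two terms.
joint_kl = kl(flatten(target_joint), flatten(model_joint))
marginal_term = kl(p_x, model_x)
conditional_term = sum(p_x[i] * kl(p_z_given_x[i], model_z_given_x[i])
                       for i in range(2))
```

Here `joint_kl` equals `marginal_term + conditional_term` up to floating-point error, and the conditional term vanishes only when the target conditional coincides with the model posterior, which is the MaL row of Table 1.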

3.1. MUTUAL INFORMATION OF THE TARGET DISTRIBUTION

Mutual Information is the key property of the joint distribution in a latent variable setting. It characterizes how dependent the observed and latent variables are. We would like to know what MI value the model is targeted at under each objective. Since each objective can be expressed as a KL divergence between the model and a target joint distribution (Table 1), we can investigate the MI of each target joint. We define the MI between $x$ and $z$ under a distribution $p(x, z)$ as:

$$MI(p(x, z)) = \mathbb{E}_{x,z\sim p(x,z)}\left[\log\frac{p(x, z)}{p(x)p(z)}\right] \tag{4}$$

For MaL, the MI of the target $p_\delta(x)p_\theta(z|x)$ is determined by the current posterior of the model, $p_\theta(z|x)$. Most models have no class preferences at initialization, which results in low MI of $p_\delta(x)p_\theta(z|x)$. Moreover, we aim to keep it unchanged, since we use the current posterior as our target posterior. Thus, low MI at initialization might induce learning a non-meaningful factorized joint throughout the training procedure. Since for ELBO the approximate posterior aligns with the true model posterior, this argument applies to ELBO too. Indeed, an uninformative posterior is a common problem when learning a latent variable model (Bowman et al., 2016; Alemi et al., 2018; Lucas et al., 2019; Razavi et al., 2019; He et al., 2019), known as "posterior collapse".

The CoLLike target is an empirical joint distribution. It represents a deterministic mapping and has constantly high MI by construction, as shown in Appendix B. Therefore, we aim to mimic a high-MI distribution with our model distribution. Furthermore, CoLLike can be interpreted as a realization of the InfoMax principle (Huszár, 2017), where the prior limits the entropy and the deterministic mapping maximizes MI.
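The contrast between a factorized target and an empirical pairing can be made concrete with a few lines of code. This is a minimal sketch with hypothetical four-sample joints; it computes the MI of an empirical joint directly from a list of $(x, z)$ pairs.

```python
import math
from collections import Counter

def mutual_information(pairs):
    """MI (in nats) of the empirical joint defined by a list of (x, z) pairs."""
    n = len(pairs)
    joint = Counter(pairs)
    px = Counter(x for x, _ in pairs)
    pz = Counter(z for _, z in pairs)
    return sum((c / n) * math.log((c / n) / ((px[x] / n) * (pz[z] / n)))
               for (x, z), c in joint.items())

# Deterministic pairing with distinct x values (the CoLLike-style target).
deterministic = [("x1", 0), ("x2", 1), ("x3", 0), ("x4", 1)]
# Fully factorized joint: every x value appears with every z value.
factorized = [("a", 0), ("a", 1), ("b", 0), ("b", 1)]

mi_deterministic = mutual_information(deterministic)
mi_factorized = mutual_information(factorized)
```

The deterministic pairing attains the entropy of the $z$ marginal, here $\log 2$ nats, while the factorized joint has zero MI.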

3.2. MATCHING IN THE LATENT DOMAIN

The joint form of the objectives from Table 1 is convenient for obtaining a perspective on distribution matching in the latent space. Treating $p_\delta(x)p_\theta(z|x)$ as a joint $p_{\delta\theta}(x, z)$, we rewrite the original MaL objective as:

$$\begin{aligned}
\mathrm{KL}(p_\delta(x)\,\|\,p_\theta(x)) &= \mathrm{KL}(p_{\delta\theta}(x, z)\,\|\,p_\theta(x, z)) \\
&= \mathbb{E}_{x,z\sim p_{\delta\theta}(x|z)p_{\delta\theta}(z)}\left[\log\frac{p_{\delta\theta}(x|z)p_{\delta\theta}(z)}{p_\theta(x|z)p_\theta(z)}\right] \\
&= \mathbb{E}_{z\sim p_{\delta\theta}(z)}\left[\mathrm{KL}(p_{\delta\theta}(x|z)\,\|\,p_\theta(x|z))\right] + \mathrm{KL}(p_{\delta\theta}(z)\,\|\,p_\theta(z))
\end{aligned} \tag{5}$$

We see that matching in the $x$ space requires matching in the $z$ space, namely $\mathrm{KL}(p_{\delta\theta}(z)\,\|\,p_\theta(z)) = 0$, where $p_{\delta\theta}(z)$ is called the aggregated posterior. This means that even though MaL is constructed such that the $z|x$ conditional part of the KL between joints is zero, we end up in a situation where none of the model marginals match the target marginals. Moreover, the learning signal from the first term of (5) might be significantly larger than the signal from the second term if the dimensions of $x$ and $z$ differ a lot. This might lead to sacrificing the second divergence in favor of the first one.

Matching in the latent domain is a known challenge of latent variable modelling (Hoffman & Johnson, 2016). A mismatch with the prior results in unnatural samples from areas where the aggregated posterior deviates strongly from the prior (Rosca et al., 2018). A number of works focus on this problem. They either use additional losses that penalize the discrepancy between marginals (Makhzani et al., 2015; Zhao et al., 2019; Kim & Mnih, 2018) or introduce a learnable prior (Bauer & Mnih, 2019; Tomczak & Welling, 2018).

In turn, CoLLike addresses this problem by constructing a conditional whose marginal matches the prior in the latent domain. The target marginal in the $x$ domain for CoLLike is always $p_\delta(x)$. In turn, the target aggregated posterior is always a sample from the prior, since $p_{\delta\pi}(z) = \int_x p_\delta(x)p_\pi(z|x)\,dx = p_\epsilon(z)$ for all values of $\pi$, where $p_\epsilon(z)$ is the empirical distribution of the sample drawn from the prior.
While it is intuitively obvious that the empirical distribution converges to the underlying distribution, one can show that KL between the empirical sample and the prior converges in probability to 0 (Cover & Thomas, 2006, Theorem 11.2.1) .
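Both statements about the CoLLike target marginal can be illustrated with a short sketch: the latent marginal of the empirical joint does not depend on the pairing, and its KL divergence to the prior can be evaluated directly. The uniform categorical prior with `K = 4` states, the sample size, and the fixed seed are assumptions made only for this illustration.

```python
import math
import random
from collections import Counter

K = 4                      # hypothetical number of latent categories
rng = random.Random(0)     # fixed seed, for repeatability only

def empirical_marginal(zs):
    """Empirical distribution p_eps(z) of a list of categorical samples."""
    n = len(zs)
    counts = Counter(zs)
    return [counts[k] / n for k in range(K)]

def kl_to_uniform_prior(zs):
    """KL(p_eps(z) || p(z)) for a uniform categorical prior p(z) = 1/K."""
    return sum(p * math.log(p * K) for p in empirical_marginal(zs) if p > 0)

zs = [rng.randrange(K) for _ in range(1000)]   # sample from the prior
perm = list(range(len(zs)))
rng.shuffle(perm)
zs_permuted = [zs[j] for j in perm]            # another pairing with x_1..x_N

kl_small = kl_to_uniform_prior(zs[:20])        # divergence for 20 samples
kl_large = kl_to_uniform_prior(zs)             # divergence for 1000 samples
```

Permuting the latents leaves `empirical_marginal` unchanged, so every member of the target family shares the same aggregated posterior. The divergence to the prior is typically smaller for the larger sample, in line with the convergence result cited above.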

4. ALGORITHM

The objective (2) includes maximization with respect to two arguments: $\pi$ and $\theta$. We approach it by alternating between maximization with respect to $\pi$ and with respect to $\theta$. We apply a stochastic minibatch technique similar to Bojanowski & Joulin (2017), which performs the maximization with respect to both $\pi$ and $\theta$ on a minibatch instead of the entire dataset and returns the latents to the dataset in the optimal order. Furthermore, we interpret the maximization with respect to $\pi$ as a linear sum assignment problem (LAP) in order to use efficient combinatorial optimization techniques (see Appendix D for the derivation). Algorithm 1 describes the resulting stochastic optimization procedure.

Algorithm 1 Stochastic optimization of CoLLike

Require: $X = (x_1, \dots, x_N)$, $p_\theta(x, z) = p_\theta(x|z)p(z)$, batch size $B$, learning rate $\eta$
Sample $Z = (z_1, \dots, z_N)$ from the prior ($z_i \sim p(z)$)
while not converged do
    Sample random indices $(i_1, \dots, i_B)$
    Compute the matrix $C \in \mathbb{R}^{B \times B}$ with $C_{q,k} = \log p_\theta(x_{i_q}, z_{i_k})$
    Compute $\pi^*$ that maximizes (2) for $(x_{i_1}, \dots, x_{i_B})$, $(z_{i_1}, \dots, z_{i_B})$ by applying a LAP solver to $C$
    $\theta \leftarrow \theta + \eta \nabla_\theta L_{CL}(\theta, \pi^*)$
    Put $(z_{i_1}, \dots, z_{i_B})$ back into $Z$ in the optimal order $(z_{\pi^*(i_1)}, \dots, z_{\pi^*(i_B)})$
end while

The core of the algorithm is the computation of the matrix $C$ and of the optimal permutation $\pi^*$. Both parts are potentially computationally intensive. Computing the matrix $C$ requires $B^2$ forward passes. Backward passes are not required for this step, hence the memory requirements are mild. Furthermore, in this work we focus on low-dimensional discrete latents. Assuming that the number of categories $K$ of the latent variable $z$ is lower than $B$, the number of distinct values of $z$ in $\log p_\theta(x_{i_q}, z)$ equals $K$ instead of $B$. Thus, we can calculate all needed values of $C$ using only $K \cdot B$ forward passes instead of $B^2$. In the supervised case, we only need the $\theta$ update part of the algorithm, which requires $B$ forward passes and $B$ backward passes. As a rough estimate, if forward and backward passes take the same time, the cost matrix adds about $K/2$ times the compute of the supervised setting.

To find $\pi^*$, we use the Hungarian algorithm (Kuhn, 1955) as the LAP solver. The algorithm takes the cost matrix $C \in \mathbb{R}^{B \times B}$ as input and produces the optimal permutation $\pi^*$ in the form of an optimal ordering of the $(z_{i_1}, \dots, z_{i_B})$ indices. The complexity of the algorithm is $O(B^3)$, which potentially limits the applicability of CoLLike to large batch sizes. However, for batch sizes regularly used in practice, solving the LAP results in only a minor increase in the overall computation time.
For instance, in our experiments, we used batch size 64. Solving the LAP took orders of magnitude less time than even the supervised update. See Appendix E for detailed timings. We also highlight that optimization with large batches is not only challenging but can also significantly reduce generalization (Xing et al., 2018; You et al., 2019; 2020). However, as shown by Huszár (2017), this kind of minibatch combinatorial optimization provides only locally optimal solutions. The size of the gap between the local and global optima is still to be determined.

The result of Algorithm 1 is a trained model. However, we are interested in the posterior $p_\theta(z|x)$. For low-dimensional categorical $z$, we can compute the posterior exactly using Bayes' rule, $p_\theta(z|x) = p_\theta(x, z)/p_\theta(x)$, since $p_\theta(x) = \sum_i p_\theta(x, z = i)$ is tractable. For other cases, we can fit an approximate posterior using regular variational techniques. We can also use the CoLLike objective to obtain estimates of $z$ values. After training a model with the CoLLike objective, we have matched arrays $X$ and $Z$. We add new $x$ samples to $X$ and sample extra $z$ values from the prior to extend $Z$. The inference can then be performed by optimizing (2) with respect to $\pi$.
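The alternating procedure of this section can be sketched end to end on a toy problem. Everything specific below is an assumption made for illustration: a one-dimensional dataset, a two-category uniform prior, a unit-variance Gaussian $p_\theta(x|z)$ whose class means `mu` play the role of $\theta$, the whole dataset used as a single batch, a fixed balanced list `Z` standing in for a sample from the prior, and a brute-force search replacing the Hungarian algorithm (feasible only for very small $B$).

```python
import itertools
import math

K = 2                                   # number of latent categories
X = [-2.1, -1.9, -2.0, 2.0, 1.9, 2.1]   # hypothetical 1-D dataset
Z = [0, 1, 0, 1, 0, 1]                  # stand-in for z_i ~ p(z), uniform over K
mu = [-0.5, 0.5]                        # model parameters theta: class means
ETA = 0.1                               # learning rate

def log_joint(x, k):
    """log p_theta(x, z=k): unit-variance Gaussian p(x|z), uniform p(z)."""
    return (-math.log(K) - 0.5 * math.log(2 * math.pi)
            - 0.5 * (x - mu[k]) ** 2)

def solve_lap(scores):
    """Brute-force LAP: permutation maximising sum_q scores[q][perm[q]]."""
    n = len(scores)
    return max(itertools.permutations(range(n)),
               key=lambda perm: sum(scores[q][perm[q]] for q in range(n)))

for step in range(50):
    # K * B model evaluations instead of B * B: one per (x, category) pair.
    table = [[log_joint(x, k) for k in range(K)] for x in X]
    C = [[table[q][Z[j]] for j in range(len(Z))] for q in range(len(X))]
    perm = solve_lap(C)
    Z = [Z[j] for j in perm]            # put latents back in the optimal order
    # Gradient ascent on L_CL: dL/dmu_k = sum of (x_i - mu_k) over x_i paired with k.
    grad = [sum(x - mu[k] for x, z in zip(X, Z) if z == k) for k in range(K)]
    mu = [m + ETA * g for m, g in zip(mu, grad)]

def posterior(x):
    """Exact p_theta(z | x) via Bayes' rule, tractable because K is small."""
    joint = [math.exp(log_joint(x, k)) for k in range(K)]
    total = sum(joint)
    return [j / total for j in joint]
```

On this toy data the latents end up sorted into the two clusters and the class means approach the cluster centres at -2 and 2. In practice a dedicated LAP solver would replace `solve_lap`, and the parameter update would be an ordinary minibatch gradient step on a neural model.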

5. CONNECTIONS

Connections with existing techniques not only give alternative perspectives on the CoLLike objective, but also provide probabilistic grounding for some existing algorithms. Many well-known objectives actually use CoLLike while being motivated as ad-hoc empirical risk minimization. We show that these objectives not only seem reasonable but are also probabilistically motivated.

While the traditional K-means algorithm (MacQueen, 1967; Lloyd, 1982) has probabilistic grounds (Murphy, 2022, 21.4.1.1), its constrained counterpart (Bennett et al., 2000) lacks probabilistic justification. Constrained K-means is equivalent to CoLLike under a factorized Gaussian $p_\theta(x|z)$ and a uniform categorical $p(z)$ whose number of states equals the number of clusters. This connection allows extending the constrained K-means approach to different generative distributions and priors. A probabilistic interpretation is given by Jitta & Klami (2018); however, the choice of the complete likelihood as an objective is not explained there.

Permutation Invariant Training (PIT) (Yu et al., 2017; Luo & Mesgarani, 2019), used in source separation, can also be expressed as a CoLLike objective. For instance, in the cocktail party problem, we want to separate a mixture of $K$ sources. During training, we have $K$ isolated mixture components and a network that produces $K$ estimates of the components based on a single mixture. We do not know which network output corresponds to which source, so we pick the permutation that produces the minimal total mismatch between outputs and sources. This procedure corresponds to training a latent variable model with the CoLLike objective, where a categorical latent variable of dimension $K$ determines the source identity. In this setting, we treat mixture components as samples in the dataset.

The closest predecessor of CoLLike is Noise As Target (NAT) (Bojanowski & Joulin, 2017), an unsupervised approach to learning an image encoder. In this approach, the representations produced by a network are assigned to a fixed collection of vectors sampled from the uniform distribution on a sphere. After this, the network parameters are adjusted to make the encodings closer to the assigned vectors. This approach is equivalent to CoLLike with the reverse model factorization $p_\theta(x, z) = p_\theta(z|x)p(x)$ and a factorized Gaussian $p_\theta(z|x)$. Other approaches that obtain a clear probabilistic interpretation through CoLLike include Sinkhorn Autoencoders (Patrini et al., 2019), simultaneous clustering and representation learning (Asano et al., 2020), and the approach of Jeong & Song (2019).

Bojanowski & Joulin (2017) noticed that the NAT objective has Optimal Transport (OT) roots. The OT framework can be used to measure the discrepancy between distributions. In particular, for a given non-negative cost function $c$, the optimal transport distance between distributions $p_\delta$ and $p_\epsilon$ is defined as

$$OT(p_\delta, p_\epsilon) = \min_{\gamma\in\Gamma(p_\delta, p_\epsilon)} \mathbb{E}_{x,z\sim\gamma(x,z)}[c(x, z)]$$

where $\Gamma(p_\delta, p_\epsilon)$ is the set of joint distributions with marginals $p_\delta$ and $p_\epsilon$. When both $p_\delta$ and $p_\epsilon$ are empirical, the search space $\Gamma$ becomes countable and finite. Now it contains only pairings between points in $p_\delta$ and points in $p_\epsilon$. Given an arbitrary initial pairing, we can express all other pairings through a permutation applied to either $x$ or $z$. In this case, the cost becomes

$$OT(p_\delta, p_\epsilon) = \min_{\pi\in\Pi} \sum_{i} c(x_i, z_{\pi(i)})$$

where $\Pi$ is the set of all permutation functions. This expression is almost the CoLLike objective (2). Choosing the cost function $c$ to be $-\log p_\theta(x, z)$ and switching to maximization makes them equivalent. Thus, CoLLike bridges maximum likelihood methods with OT. This connection allows bringing the latest developments in OT to likelihood-based methods. Furthermore, in Appendix C, we provide an example of the equivalence between CoLLike and the Wasserstein distance. In that case, the complete likelihood of the model plays the roles of both a mapping from the $z$ to the $x$ domain and a distance metric.
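The constrained K-means connection stated earlier in this section can be made explicit with a short worked derivation. The notation is introduced here for illustration only: we assume that the factorized Gaussian has a shared isotropic variance $\sigma^2$, that $x \in \mathbb{R}^D$, and that $\mu_k$ denotes the mean of category $k$ out of $K$.

```latex
\begin{aligned}
-\log p_\theta(x_i, z_{\pi(i)})
  &= -\log \mathcal{N}\!\left(x_i;\ \mu_{z_{\pi(i)}},\ \sigma^2 I\right) - \log p(z_{\pi(i)}) \\
  &= \frac{1}{2\sigma^2}\left\lVert x_i - \mu_{z_{\pi(i)}} \right\rVert_2^2
     + \frac{D}{2}\log\!\left(2\pi\sigma^2\right) + \log K , \\
\arg\max_{\mu,\ \pi}\ L_{CL}(\mu, \pi)
  &= \arg\min_{\mu,\ \pi}\ \sum_{i=1}^{N} \left\lVert x_i - \mu_{z_{\pi(i)}} \right\rVert_2^2 .
\end{aligned}
```

Under these assumptions the last two terms of the second line do not depend on $\mu$ or $\pi$, so maximizing (2) reduces to a sum-of-squares clustering objective. Because the sample $Z$ is fixed, permutations can only reassign points between clusters while keeping the number of points per cluster unchanged, which is where the constraint on cluster sizes comes from.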

6. EXPERIMENTS

In this work, we focus on low-dimensional discrete latents. This type of latent variable allows a direct comparison with the exact likelihood. Furthermore, we emphasize our focus on learning a useful z|x rather than simplifying the model with a factorized x|z conditional. Models with tractable likelihood are well suited for comparing likelihood-based algorithms because they remove the problem of likelihood estimation precision: all quantities of interest can be computed exactly. Moreover, tractable likelihood allows comparing CoLLike directly with MaL instead of its approximations such as ELBO.

We use the MNIST (LeCun et al., 1998) and CIFAR (Krizhevsky, 2009) datasets for the image modality and AG News (Zhang et al., 2015) for the text domain. All these datasets come with class labels. For images, we train a Glow-like normalizing flow conditioned on a discrete latent variable with 10 categories through all coupling layers. For text, we use a Transformer Language Model conditioned on a discrete latent variable with 4 categories using an additive embedding for all tokens. The size of the discrete variable is equal to the number of classes in the underlying dataset. The small number of categories makes it possible to compute the exact marginal likelihood and speeds up the computation of the cost matrix C in Algorithm 1. Schematic representations of the architectures are provided in Figure 2.

CoLLike exhibits near-zero aggregated KL in all experiments, which implies that the model marginal in the latent domain closely matches the prior. For MaL, the aggregated KL is zero only for the AG News dataset, where the model also has an uninformative factorized joint. For the other datasets, the aggregated posterior deviates significantly from the prior. We also note that for the MNIST dataset, MaL puts all probability mass on a single category in half of the runs.

To estimate the quality of unsupervised classification, we perform the optimal assignment of latent categories to classes. For all cases except the CoLLike objective on the AG News dataset, the quality of the unsupervised classification is similar and low. On AG News, the unsupervised accuracy is exceptionally good. However, the variance of the proposed solution is relatively high: the standard deviation of the accuracy across 4 runs is 5.4, with the highest value of 87.1 and the lowest of 73.3. In the following section, we show that it is possible to achieve significantly higher unsupervised accuracy and lower variance by latent variable ensembling.

Overall, CoLLike clearly outperforms MaL in the tractable likelihood setting. Moreover, it shows high unsupervised classification accuracy for the text modality. For MaL, the experiments exhibit a variety of failures, from posterior collapse to a degenerate aggregated posterior, which extends the findings of Lucas et al. (2019) to expressive models and exact likelihood. However, even when CoLLike produces informative latents in terms of MI, unsupervised classification might still be challenging. We believe that the key to high-performance unsupervised classification lies in the right inductive biases in the conditioning and in the type of probabilistic model.

6.1. LATENT ENSEMBLING

To reduce the high variance of the CoLLike unsupervised classification accuracy and to increase the accuracy itself, we propose ensembling multiple models trained on the same data with different initialization seeds. Although the labels of different latent variable models do not correspond to each other, we can find a label assignment based on the agreement between the models. This approach is motivated by direct cluster ensembling (Boongoen & Iam-on, 2018). The agreement between two labels of different ensemble members is the number of samples that carry both labels. To align the latents, we iteratively find the assignment with the highest intersection between labels. Finally, we find the assignment between the aligned latents and the ground truth labels.

In our experiments, we use 8 models per ensemble and train 4 independent ensembles. The simplest ensembling method is averaging the predictions. It increases the mean unsupervised accuracy from 82.1 to 84.5 and reduces the standard deviation from 5.4 to 1.7. We further improve these results by utilizing the agreement score, which is also used for the alignment of the labels: we pick the top-k models with the highest maximum coherence across the other models in the ensemble. Averaging the predictions of those top-k models further increases the accuracy to 86.6 and lowers the standard deviation to 0.2.

We compare the CoLLike results with the following unsupervised and supervised approaches: PET and iPET (Schick & Schütze, 2021), EFL (Wang et al., 2021), LM-BFF (Gao et al., 2021), and DocSCAN (Stammbach & Ash, 2021). DocSCAN is purely unsupervised, while the other approaches rely on engineering multiple textual descriptions of classes (prompts) or on labeled data. All these methods use heavy pre-trained Transformers (Vaswani et al., 2017) as an initialization, while for CoLLike we use a small 2-layer Transformer with random initialization. Figure 3a presents the comparison of the methods. CoLLike clearly outperforms both the unsupervised and most of the supervised methods.

To determine how much labeled training data is needed without possibly laborious prompt engineering, we use DeBERTa v3 (He et al., 2021). We vary the training set size from 32 to 2048 and apply additional ensembling of 8 models with different initializations and train-validation splits. Figure 3b reveals that CoLLike can be a better alternative to labeling more than a hundred samples, which, in turn, requires extensive data analysis. Besides, note the large difference between the ensemble and the single model for small dataset sizes in the supervised setting, which is an interesting result in itself.
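The label alignment step described in this section can be sketched as follows. This is a simplified illustration rather than the procedure used in the experiments: the function names and toy label lists are hypothetical, alignment is done by brute force over all relabellings instead of iteratively, and hard labels are combined by majority vote instead of averaging predicted probabilities.

```python
import itertools
from collections import Counter

def align_labels(reference, other, num_classes):
    """Relabel `other` so that it agrees with `reference` as much as possible.

    Agreement between label a (reference) and label b (other) is the number
    of samples carrying both labels; the best one-to-one relabelling is found
    by brute force over all num_classes! permutations.
    """
    overlap = Counter(zip(other, reference))
    best = max(itertools.permutations(range(num_classes)),
               key=lambda perm: sum(overlap[(b, perm[b])]
                                    for b in range(num_classes)))
    return [best[b] for b in other]

def majority_vote(aligned_predictions):
    """Combine aligned hard labels from several models, sample by sample."""
    return [Counter(votes).most_common(1)[0][0]
            for votes in zip(*aligned_predictions)]

# Hypothetical latent labels from three models on six samples.
model_a = [0, 0, 1, 1, 2, 2]
model_b = [2, 2, 0, 0, 1, 1]   # same clustering as model_a, different names
model_c = [1, 1, 2, 2, 0, 1]   # disagrees with the others on the last sample

aligned_b = align_labels(model_a, model_b, num_classes=3)
aligned_c = align_labels(model_a, model_c, num_classes=3)
ensemble = majority_vote([model_a, aligned_b, aligned_c])
```

After alignment, the second model reproduces the labels of the first one exactly, and the vote overrides the single disagreement of the third model.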

7. DISCUSSION AND FUTURE WORK

In this work, we propose to switch from the MaL paradigm of matching only the marginals in the observed domain to the CoLLike paradigm of finding an exact target joint by selecting it from a family of joints with desirable properties. Furthermore, we show that the matching of marginals used by MaL corresponds to a specific choice of target joint, which explains such failures as posterior collapse and divergence between the target and model marginals in the latent domain. We show experimentally that CoLLike is able to learn useful representations. The connection of CoLLike with OT makes it possible to borrow techniques from the latter. For instance, Sinkhorn relaxation (Cuturi, 2013) can be used to speed up the assignment problem. The investigation of alternatives to the complete likelihood for target selection is of special interest. The right inductive biases for inducing useful properties with CoLLike are still to be discovered, at least as long as we want to obtain the desired representation without specifying what we want. We believe that a further extension of CoLLike to high-dimensional latents would be exciting and challenging. Other lines of research can be devoted to applying other divergences to the constructed family of joint target candidates and to extending CoLLike to learnable priors.

8. REPRODUCIBILITY

To promote reproducibility we open-source our code the link is hidden for double-blind review, check the supplementary materials. Furthermore, we describe details of flow architectures in Appendix F.1 and Transformers in Appendix F.2. Along with models, we describe details of the training procedures and data pre-processing. We also devote special attention to setting all necessary seeds, including CUDA, and to removing stochasticity from the BPE tokenizer. Figure 4 : Example of three joint distributions with discrete x and z. Lines on the left and in the bottom depict the number of samples from empirical marginals with the corresponding value of the random variable. Squares reflect the joint probability value. Each line crossing the square corresponds to 1/N probability added to the corresponding (x, z) random variable pair. and z domain, the permutation can change the entropy of the joint. For instance, for the distribution on the left, the entropy H(p δ (x)p π (z|x)) = -( 1 4 log 1 4 + 1 4 log 1 4 + 1 2 log 1 2 ) ≈ 1.04 nats. For the distribution in the center, the entropy H(p δ (x)p π (z|x)) = -(4• 1 4 log 1 4 ) ≈ 1.39 nats. So, depending on π we might end up with more and less entropic distributions. However, when we restrict any empirical marginal to take only distinct values, like in the right part of Figure 4 , the situation changes. Namely, each distinct pair (x, z) can be chosen at most once, because to choose it twice we need a duplicate sample in both domains. This can be verified using the right part of Figure 4 . Just try to construct a joint with some square having grater than one line assuming that each value x marginal has only one line. Moreover, for x, this is a reasonable assumption since usually the domain of x is high-dimensional. For the case, the joint will contain N non-zero points each with probability 1/N . 
Thus, the entropy of the empirical distribution is equal to

$$H(p_\delta(x)p_\pi(z|x)) = -\sum_{i=1}^{N} \frac{1}{N}\log\frac{1}{N} = \log N.$$

Under the assumption that $p_\delta(x)$ contains only distinct elements, the conditional $p_{\delta\pi}(z|x) \equiv 1$ for all values $x$ from $p_\delta(x)$. So, choosing $x$ uniquely determines the value of $z$, as can be seen from the right part of Figure 4. Then the mutual information is given by

$$MI(p_\delta(x)p_\pi(z|x)) = E_{x,z \sim p_\delta(x)p_\pi(z|x)} \log\frac{p_{\delta\pi}(z|x)}{p_{\delta\pi}(z)} = E_{z \sim p_{\delta\pi}(z)} \log\frac{1}{p_{\delta\pi}(z)} = H(p_{\delta\pi}(z)). \quad (8)$$

So, the mutual information is always equal to the entropy of the empirical prior. It is possible to show that this value is the maximum possible one. This becomes obvious from the entropic factorization of the mutual information

$$MI(p(x, z)) = H(p(z)) - E_{x \sim p(x)}\left[H(p(z|x))\right]. \quad (9)$$

Since the entropy is non-negative, the mutual information can be decreased only through the second term of (9), which equals 0 because the value of $z$ is completely determined by $x$.

When we try to extend the observations above to the continuous case, we face the following challenge: the empirical distribution has infinite values at the sample points. This drives the differential entropy, as well as the mutual information, to infinity. However, adding noise to the empirical distribution solves this problem. Adding uniform noise on an interval smaller than the floating-point precision makes the entropy finite and constant with respect to $\pi$. One can show that the resulting mutual information of the empirical joint also equals $\log N$.
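The discrete-case observations above can be checked numerically. The following is a minimal sketch, not the authors' code: the sample values `xs_dup`, `zs_dup`, `xs_distinct`, `zs_distinct` and all helper names are illustrative assumptions, and the marginals with duplicates are chosen only to reproduce the two entropy values quoted for Figure 4, not to match the figure itself.

```python
import math
from collections import defaultdict
from itertools import permutations


def entropy(probs):
    """Shannon entropy (in nats) of a discrete distribution given as probabilities."""
    return -sum(p * math.log(p) for p in probs if p > 0)


def empirical_joint(xs, zs, perm):
    """Empirical joint: every pair (x_i, z_perm[i]) receives probability mass 1/N."""
    n = len(xs)
    joint = defaultdict(float)
    for i in range(n):
        joint[(xs[i], zs[perm[i]])] += 1.0 / n
    return dict(joint)


def mutual_information(joint):
    """MI = H(x) + H(z) - H(x, z) for a joint given as {(x, z): probability}."""
    px, pz = defaultdict(float), defaultdict(float)
    for (x, z), p in joint.items():
        px[x] += p
        pz[z] += p
    return entropy(px.values()) + entropy(pz.values()) - entropy(joint.values())


# Hypothetical samples with duplicates in both domains: the entropy depends on perm.
xs_dup, zs_dup = ["a", "a", "b", "c"], [0, 0, 1, 2]
h_low = entropy(empirical_joint(xs_dup, zs_dup, (0, 1, 2, 3)).values())   # masses 1/2, 1/4, 1/4
h_high = entropy(empirical_joint(xs_dup, zs_dup, (0, 2, 1, 3)).values())  # masses 4 x 1/4

# Hypothetical samples with distinct values: entropy and MI equal log N for every perm.
xs_distinct, zs_distinct = ["a", "b", "c", "d"], [0, 1, 2, 3]
stats_distinct = [
    (entropy(j.values()), mutual_information(j))
    for j in (empirical_joint(xs_distinct, zs_distinct, p) for p in permutations(range(4)))
]

print(f"with duplicates: {h_low:.2f} vs {h_high:.2f} nats")
print(f"distinct samples: log N = {math.log(4):.2f} nats for all {len(stats_distinct)} permutations")
```

With distinct samples, every one of the 24 permutations gives the same entropy and the same mutual information, which is the constancy used when treating the entropy term as independent of $\pi$.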

C WASSERSTEIN DISTANCE AND COLLIKE

The Optimal Transport cost becomes a Wasserstein distance when $c$ is a metric. A very illustrative example from this family is the equality of Wasserstein-2 (the ground distance is Euclidean) and CoLLike for some setups. Specifically, the following objective can be produced both by the Wasserstein distance and by CoLLike:

$$L_W(\theta) = \min_{\pi \in \Pi} \sum_i \left\lVert x_i - f_\theta(z_{\pi(i)}) \right\rVert_2^2$$

To get this objective from the OT perspective, we define the model distribution as the result of passing a fixed sample from the prior through a deterministic decoder $f_\theta(z)$. The result is an empirical distribution in the x domain. The Wasserstein distance between two empirical distributions is determined by the optimal pairing between points from the data distribution $p_\delta(x)$ and the model distribution spanned by the empirical latents. The same objective is produced by a factorized Gaussian $p_\theta(x|z)$ and a uniform prior $p(z)$. This connection demonstrates that the model $p_\theta(x|z)$ defines both the mapping from the z domain to the x domain and the "topology" of the x space (how we measure the distance between objects). However, approaches based on the Wasserstein distance are limited to continuous variables, while CoLLike is applicable to both discrete and continuous domains. Moreover, CoLLike provides a probabilistic basis for the choice of the cost function.
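The correspondence can be illustrated on a toy problem. The sketch below assumes a hypothetical linear decoder `decoder`, unit-variance Gaussian noise, and hand-picked two-dimensional points; none of these come from the paper. It compares the pairing that minimizes the summed squared Euclidean distance with the pairing that minimizes the summed Gaussian negative log-likelihood, using brute-force search over permutations.

```python
import math
from itertools import permutations


def sq_dist(x, y):
    """Squared Euclidean distance between two points given as lists."""
    return sum((a - b) ** 2 for a, b in zip(x, y))


def gaussian_nll(x, mean, sigma=1.0):
    """Negative log-density of a factorized Gaussian N(mean, sigma^2 I) at x."""
    d = len(x)
    return 0.5 * sq_dist(x, mean) / sigma**2 + 0.5 * d * math.log(2 * math.pi * sigma**2)


def decoder(z, theta=2.0):
    """Hypothetical deterministic decoder f_theta(z): element-wise scaling."""
    return [theta * zi for zi in z]


def best_permutation(cost):
    """Brute-force search for the permutation with the smallest total cost."""
    n = len(cost)
    return min(permutations(range(n)), key=lambda p: sum(cost[i][p[i]] for i in range(n)))


# Illustrative data points and a fixed sample from the prior (assumed values).
xs = [[0.1, 1.9], [4.2, -0.3], [-2.0, 0.5]]
zs = [[-1.0, 0.2], [0.0, 1.0], [2.0, 0.0]]

# Cost matrices: entry [i][j] pairs data point x_i with latent sample z_j.
# With a uniform prior, log p(z_j) is a constant and does not affect the argmin.
cost_w2 = [[sq_dist(x, decoder(z)) for z in zs] for x in xs]
cost_nll = [[gaussian_nll(x, decoder(z)) for z in zs] for x in xs]

pi_w2 = best_permutation(cost_w2)
pi_nll = best_permutation(cost_nll)
print(pi_w2, pi_nll)
```

Because the Gaussian negative log-likelihood is an affine function of the squared distance with a positive slope, both cost matrices lead to the same optimal pairing.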

D OPTIMAL PAIRING BY COMBINATORIAL OPTIMIZATION

Having at hand the log-likelihood values for all possible $(x_i, z_j)$ pairs, we are ready to find the optimal permutation. A naive way to do so is to evaluate the sum in (2) for every possible permutation $\pi$. Although we only need to sum different pre-computed values, the search space for $\pi$ is tremendous: it contains $N!$ permutations. However, we can cast this problem as a combinatorial optimization one. Following Papadimitriou & Steiglitz (1982), the assignment problem is stated as follows:

$$\begin{aligned} \text{minimize} \quad & \sum_{i,j} c_{i,j}\, a_{i,j} \\ \text{subject to} \quad & \sum_i a_{i,j} = 1, \quad j = 1, \dots, N \\ & \sum_j a_{i,j} = 1, \quad i = 1, \dots, N \\ & a_{i,j} \in \{0, 1\} \end{aligned}$$

where $c_{i,j}$ is the cost of picking the element $(i, j)$ and $a_{i,j}$ is the indicator variable. The constraints of this problem define the set of permutation matrices. By choosing the cost to be the negative log-likelihood and replacing the indicator variable with a permutation, we end up with the CoLLike objective. This combinatorial optimization problem can be solved efficiently with the Hungarian algorithm (Kuhn, 1955), which has a complexity of $O(N^3)$.

As a data pre-processing step, we used only dequantization with uniform noise, with the range equal to the quantization step.
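To make the assignment problem concrete, the following sketch contrasts the naive search over all permutations with a compact potentials-based variant of the Hungarian method that runs in $O(N^3)$. It is an illustration under assumptions rather than the implementation used in the experiments: the function names and the toy cost matrix are hypothetical, and in practice a library solver such as `scipy.optimize.linear_sum_assignment` could serve the same purpose.

```python
from itertools import permutations


def brute_force_assignment(cost):
    """Naive search: evaluate the total cost of every permutation (N! candidates)."""
    n = len(cost)
    return min(
        (list(p) for p in permutations(range(n))),
        key=lambda p: sum(cost[i][p[i]] for i in range(n)),
    )


def hungarian(cost):
    """O(N^3) Hungarian method with row/column potentials for a square cost matrix.

    Returns a list `assignment` where row i is paired with column assignment[i].
    """
    n = len(cost)
    inf = float("inf")
    u = [0.0] * (n + 1)       # row potentials (1-indexed)
    v = [0.0] * (n + 1)       # column potentials (1-indexed)
    match = [0] * (n + 1)     # match[j] = row currently assigned to column j
    way = [0] * (n + 1)       # previous column on the augmenting path
    for i in range(1, n + 1):
        match[0] = i
        j0 = 0
        minv = [inf] * (n + 1)
        used = [False] * (n + 1)
        while True:
            used[j0] = True
            i0, delta, j1 = match[j0], inf, 0
            for j in range(1, n + 1):
                if not used[j]:
                    cur = cost[i0 - 1][j - 1] - u[i0] - v[j]
                    if cur < minv[j]:
                        minv[j], way[j] = cur, j0
                    if minv[j] < delta:
                        delta, j1 = minv[j], j
            for j in range(n + 1):
                if used[j]:
                    u[match[j]] += delta
                    v[j] -= delta
                else:
                    minv[j] -= delta
            j0 = j1
            if match[j0] == 0:
                break
        # Flip the assignments along the augmenting path.
        while j0 != 0:
            j1 = way[j0]
            match[j0] = match[j1]
            j0 = j1
    assignment = [0] * n
    for j in range(1, n + 1):
        assignment[match[j] - 1] = j - 1
    return assignment


def total_cost(cost, assignment):
    """Sum of the costs selected by the assignment."""
    return sum(cost[i][j] for i, j in enumerate(assignment))


# Toy cost matrix standing in for negative log-likelihoods -log p(x_i, z_j).
toy_cost = [[4, 1, 3], [2, 0, 5], [3, 2, 2]]
pi_fast = hungarian(toy_cost)
pi_naive = brute_force_assignment(toy_cost)
print(pi_fast, total_cost(toy_cost, pi_fast), total_cost(toy_cost, pi_naive))
```

The returned list is a permutation, so each row and each column is used exactly once, which is the constraint set of the assignment problem.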



We find the Greek letter δ especially suitable for the data distribution because it is consonant with "data" and reflects the delta-function-like form of the empirical distribution.

$q_\phi(z|x)$ is an approximate posterior distribution parametrized by $\phi$.

This is similar to the EM algorithm (Dempster et al., 1977) in the sense that EM alternates between maximization of the lower bound tightness (expectation step) and maximization of the resulting tight bound (maximization step).

A cautious reader might note that for continuous variables the non-negativity of $-\log p_\theta$ can be violated; however, all model densities used in practice are finite, and the corresponding cost can be made positive just by adding a constant, which does not change the optimization problem.



Figure 1: Illustration of the CoLLike (left) and MaL (right) objectives. Triangles depict sample values. Filled circles represent $p_\theta(x, z)$ for all possible (x, z) pairs. Double circles indicate the optimal $\pi$. Bold lines and double circles are the areas of the joint to be maximized.

Figure 2: Architectures used for the image (left) and text (right) domains. For the image domain, the three flow blocks are repeated 21 (CIFAR) and 14 (MNIST) times. Every coupling block is conditioned on z.

Figure 3: Comparison of ensembled CoLLike with supervised (a) and unsupervised/few-shot (b) methods. Panel (a): supervised DeBERTa v3 base vs. CoLLike ensemble. Panel (b): CoLLike ensemble vs. unsupervised and semi-supervised approaches.

The entropy of such an empirical joint is $\log N$. The observation above allows us to easily derive the mutual information of the empirical joint. Mutual information is defined as

$$MI(p(x, z)) = E_{x,z \sim p(x,z)} \log\frac{p(x, z)}{p(x)p(z)} = E_{x,z \sim p(x,z)} \log\frac{p(z|x)}{p(z)}$$

Figure 5 depicts the dependency between the size of the LAP problem and the time the Hungarian algorithm takes to solve it. The input to the algorithm is a matrix $C \in \mathbb{R}^{B \times B}$, where $B$ is the size of the problem.

Figure 5: Time to solve LAP with Hungarian algorithm for different sizes of the problem.

Results for tractable categorical latents. MNIST, CIFAR: BPD; AG News: NLL.

, $p_\epsilon$) is the set of all joint distributions on $x$ and $z$ with marginals $p_\delta(x)$ and $p_\epsilon(z)$, respectively. Furthermore, if we use a parametric model $p_\theta$ in place of $p_\epsilon$, we can fit it by minimizing the distance. Note that in this case we minimize a function that already contains a min inside.

The table presents the results of training the latent variable models with the CoLLike and MaL objectives, averaged across 4 runs. Both objectives exhibit similar performance in terms of likelihood across datasets. However, other characteristics vary. MI is high for the CoLLike objective on every dataset. Furthermore, it attains an approximately maximal value for AG News and CIFAR. MI for the MaL objective ranges from zero to values significantly lower than those of CoLLike. Zero MI indicates posterior collapse, which is mainly observed in ELBO optimization and was recently discovered by Lucas et al. (2019) for MaL applied to simple linear models. This experiment points to an important observation: posterior collapse can also happen in deep latent variable models during optimization of the exact MaL, despite usually being attributed to the structure of the ELBO. Importantly, for the MNIST dataset, half of the runs exhibit posterior collapse.

We use a Glow-like normalizing flow for all image experiments. We choose the learning rate by starting from 1e-2 and gradually decreasing it until there are no instabilities during training. No extensive learning rate search was done. Below we provide details on the model parameters.

A DERIVATION OF KL FORMS OF THE CONSIDERED OBJECTIVES

The equivalence between the CoLLike objective (2) and its KL divergence form from Table 1 can be derived as follows:

$$KL(p_\delta(x)p_\pi(z|x) \,\|\, p_\theta(x, z)) = E_{x,z \sim p_\delta(x)p_\pi(z|x)} \log p_\delta(x)p_\pi(z|x) - E_{x,z \sim p_\delta(x)p_\pi(z|x)} \log p_\theta(x, z), \quad (6)$$

where the first term in (6) is treated as constant under the assumption that all samples from $p_\delta(x)$ take distinct values, which is reasonable for such high-dimensional objects as images, texts, and sounds. Thus, the KL form of the objective is equivalent to the CoLLike objective up to a multiplicative factor and an additive term. For a proof of the constancy, see Appendix B.

The derivation of the equivalence between (1) and its KL form from Table 1 is as follows.

The KL form of the ELBO objective from Table 1 can be found in many works (Zhao et al., 2019; Kingma & Welling, 2019); however, we provide a derivation here to make the paper self-contained.
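As a numerical sanity check of the CoLLike part of this appendix, the sketch below verifies on a toy discrete model that the KL divergence between the empirical joint and the model joint equals the negative entropy of the empirical joint minus the average complete log-likelihood, and that the entropy term does not depend on $\pi$ when all samples are distinct. The model table `p_theta`, the samples `xs` and `zs`, and the helper names are illustrative assumptions, not taken from the paper.

```python
import math
from itertools import permutations

# Hypothetical model joint p_theta(x, z) for x, z in {0, 1, 2}; the entries sum to 1.
p_theta = [
    [0.20, 0.05, 0.05],
    [0.05, 0.25, 0.05],
    [0.05, 0.05, 0.25],
]

# Distinct observed samples and a sample from the prior of the same size (assumed values).
xs = [0, 1, 2]
zs = [0, 1, 2]
n = len(xs)


def kl_empirical_vs_model(perm):
    """KL between the empirical joint (mass 1/N on each (x_i, z_perm[i])) and p_theta."""
    return sum((1.0 / n) * math.log((1.0 / n) / p_theta[xs[i]][zs[perm[i]]]) for i in range(n))


def mean_complete_loglik(perm):
    """Average complete log-likelihood (1/N) * sum_i log p_theta(x_i, z_perm[i])."""
    return sum(math.log(p_theta[xs[i]][zs[perm[i]]]) for i in range(n)) / n


perms = list(permutations(range(n)))

# KL + average log-likelihood should equal -H = -log N for every permutation.
residuals = [kl_empirical_vs_model(p) + mean_complete_loglik(p) for p in perms]

best_by_kl = min(perms, key=kl_empirical_vs_model)
best_by_loglik = max(perms, key=mean_complete_loglik)
print(best_by_kl, best_by_loglik, residuals[0], -math.log(n))
```

Since the residual is the same constant for every permutation, minimizing the KL form over $\pi$ (and over the model parameters) selects the same solution as maximizing the complete likelihood.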

B ENTROPY AND MUTUAL INFORMATION OF EMPIRICAL JOINT

In this appendix, we derive some useful properties of the empirical joint distributions produced by sampling from the prior. The joint distribution $p_\delta(x)p_\pi(z|x)$ depends on $\pi$. We focus on how $\pi$ influences such characteristics of the distribution as entropy and mutual information.

Consider a joint distribution over discrete x and z. This kind of distribution can be visualized as a table, such as the one depicted in Figure 4. If there are multiple samples taking the same value both in x

F.2 TRANSFORMER

We used a simple two-layer Transformer across our experiments. The model description:

• number of layers: 2
• hidden size: 128
• feedforward dimension: 128
• embedding dimension: 128
• number of attention heads: 4
• number of embeddings: 4000

The training details are similar to the CIFAR configuration:

• epochs: 256
• learning rate: 2e-4 - CIFAR
• batch size: 64
• validation part of the training set: 0.05
• validation criterion: marginal likelihood
• optimizer: Adam, β = (0.9, 0.999); ϵ = 1e-8

In the pre-processing step, we truncate sequences longer than 192 tokens. Truncation affects less than 0.3% of the samples. Nevertheless, the tokenizer is trained on the full-length sequences. Data pre-processing can be summarized as follows:

• maximum length truncation: 192
• BPE tokenization
• vocabulary size: 4000

