LEARNABLE TOPOLOGICAL FEATURES FOR PHYLOGENETIC INFERENCE VIA GRAPH NEURAL NETWORKS

Abstract

Structural information of phylogenetic tree topologies plays an important role in phylogenetic inference. However, finding appropriate topological structures for specific phylogenetic inference tasks often requires significant design effort and domain expertise. In this paper, we propose a novel structural representation method for phylogenetic inference based on learnable topological features. By combining the raw node features that minimize the Dirichlet energy with modern graph representation learning techniques, our learnable topological features can provide efficient structural information of phylogenetic trees that automatically adapts to different downstream tasks without requiring domain expertise. We demonstrate the effectiveness and efficiency of our method on a simulated data tree probability estimation task and a benchmark of challenging real data variational Bayesian phylogenetic inference problems.

1. INTRODUCTION

Phylogenetics is an important discipline of computational biology where the goal is to identify the evolutionary history and relationships among individuals or groups of biological entities. In statistical approaches to phylogenetics, this has been formulated as an inference problem on hypotheses of shared history, i.e., phylogenetic trees, based on observed sequence data (e.g., DNA, RNA, or protein sequences) under a model of evolution. The phylogenetic tree defines a probabilistic graphical model, based on which the likelihood of the observed sequences can be efficiently computed (Felsenstein, 2003) . Many statistical inference procedures therefore can be applied, including maximum likelihood and Bayesian approaches (Felsenstein, 1981; Yang & Rannala, 1997; Mau et al., 1999; Huelsenbeck et al., 2001) . Phylogenetic inference, however, has been challenging due to the composite parameter space of both continuous and discrete components (i.e., branch lengths and the tree topology) and the combinatorial explosion in the number of tree topologies with the number of sequences. Harnessing the topological information of trees hence becomes crucial in the development of efficient phylogenetic inference algorithms. For example, by assuming conditional independence of separated subtrees, Larget (2013) showed that conditional clade distributions (CCDs) can provide more reliable tree probability estimation that generalizes beyond observed samples. A similar approach was proposed to design more efficient proposals for tree movement when implementing Markov chain Monte Carlo (MCMC) algorithms for Bayesian phylogenetics (Höhna & Drummond, 2012) . Utilizing more sophisticated local topological structures, CCDs were later generalized to subsplit Bayesian networks (SBNs) that provide more flexible distributions over tree topologies (Zhang & Matsen IV, 2018) . 
Besides MCMC, variational Bayesian phylogenetic inference (VBPI) was recently proposed that leverages SBNs and a structured amortization of branch lengths to deliver competitive posterior estimates in a more timely manner (Zhang & Matsen IV, 2019; Zhang, 2020; Zhang & Matsen IV, 2022). Azouri et al. (2021) used a machine learning approach to accelerate maximum likelihood tree-search algorithms by providing more informative topology moves. Topological features have also been found useful for comparison and interpretation of reconstructed phylogenies (Matsen IV, 2007; Hayati et al., 2022). While these approaches prove effective in practice, they all rely on heuristic features (e.g., clades and subsplits) of phylogenetic trees that often require significant design effort and domain expertise, and may be insufficient for capturing complicated topological information.

Graph Neural Networks (GNNs) are an effective framework for learning representations of graph-structured data. To encode the structural information of graphs, GNNs follow a neighborhood aggregation procedure that computes the representation vector of a node by recursively aggregating and transforming the representation vectors of its neighboring nodes. After the final iteration of aggregation, the representation of the entire graph can also be obtained by pooling all the node embeddings together via some permutation invariant operator (Ying et al., 2018). Many GNN variants have been proposed and have achieved superior performance on both node-level and graph-level representation learning tasks (Kipf & Welling, 2017; Hamilton et al., 2017; Li et al., 2016; Zhang et al., 2018; Ying et al., 2018). A natural idea, therefore, is to adapt GNNs to phylogenetic models for automatic topological feature learning. However, the lack of node features for phylogenetic trees makes this challenging, as most GNN variants assume fully observed node features at initialization.
In this paper, we propose a novel structural representation method for phylogenetic inference that automatically learns efficient topological features based on GNNs. To obtain the initial node features for phylogenetic trees, we follow previous studies (Zhu & Ghahramani, 2002; Rossi et al., 2021) and minimize the Dirichlet energy, with one-hot encoding for the tip nodes. Unlike these previous studies, we present a fast linear time algorithm for Dirichlet energy minimization by taking advantage of the hierarchical structure of phylogenetic trees. Moreover, we prove that these features are sufficient for identifying the corresponding tree topology, i.e., there is no information loss in our raw feature representations of phylogenetic trees. These raw node features are then passed to GNNs for the more sophisticated structure representation learning required by downstream tasks. Experiments on a synthetic data tree probability estimation problem and a benchmark of challenging real data variational Bayesian phylogenetic inference problems demonstrate the effectiveness and efficiency of our method.

2. BACKGROUND

Notation. A phylogenetic tree is denoted as $(\tau, q)$, where $\tau$ is a bifurcating tree that represents the evolutionary relationships of the species and $q$ is a non-negative branch length vector that characterizes the amount of evolution along the edges of $\tau$. The tip nodes of $\tau$ correspond to the observed species and the internal nodes of $\tau$ represent the unobserved characters (e.g., DNA bases) of the ancestral species. The transition probability $P_{ij}(t)$ from character $i$ to character $j$ along an edge of length $t$ is often defined by a continuous-time substitution model (e.g., Jukes & Cantor (1969)), whose stationary distribution is denoted as $\eta$. Let $E(\tau)$ be the set of edges of $\tau$ and $r$ be the root node (or any internal node if the tree is unrooted and the substitution model is reversible). Let $Y = \{Y_1, Y_2, \ldots, Y_M\} \in \Omega^{N \times M}$ be the observed sequences (with characters in $\Omega$) of length $M$ over $N$ species.

Phylogenetic posterior. Assuming different sites $Y_i, i = 1, \ldots, M$, are independent and identically distributed, the likelihood of observing $Y$ given the phylogenetic tree $(\tau, q)$ takes the form

$$p(Y|\tau, q) = \prod_{i=1}^M p(Y_i|\tau, q) = \prod_{i=1}^M \sum_{a^i} \eta(a^i_r) \prod_{(u,v)\in E(\tau)} P_{a^i_u a^i_v}(q_{uv}), \qquad (1)$$

where $a^i$ ranges over all extensions of $Y_i$ to the internal nodes, with $a^i_u$ being the assigned character of node $u$. The above phylogenetic likelihood function can be computed efficiently through the pruning algorithm (Felsenstein, 2003). Given a prior distribution $p(\tau, q)$ on the tree topology and the branch lengths, Bayesian phylogenetics then amounts to properly estimating the phylogenetic posterior $p(\tau, q|Y) \propto p(Y|\tau, q)p(\tau, q)$.

Variational Bayesian phylogenetic inference. Let $Q_\phi(\tau)$ be an SBN-based distribution over tree topologies and $Q_\psi(q|\tau)$ be a distribution over the non-negative branch lengths.
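The likelihood factorization in equation 1 can be made concrete with a small sketch of the pruning algorithm. The toy tree, branch lengths, and observed tip states below are made up for illustration; only the Jukes-Cantor transition probabilities and the sum-product structure follow the text. A brute-force sum over all internal-node assignments $a^i$ checks the recursion.

```python
import math
from itertools import product

def jc_prob(i, j, t):
    """Jukes-Cantor transition probability over 4 DNA states:
    P(same) = 1/4 + 3/4 e^{-4t/3}, P(diff) = 1/4 - 1/4 e^{-4t/3}."""
    e = math.exp(-4.0 * t / 3.0)
    return 0.25 + 0.75 * e if i == j else 0.25 - 0.25 * e

# Toy rooted tree: root -> (A, u), u -> (B, C); node -> [(child, branch length)].
children = {"root": [("A", 0.1), ("u", 0.2)], "u": [("B", 0.3), ("C", 0.15)]}
tips = {"A": 0, "B": 2, "C": 2}  # observed states at a single site
STATES = range(4)

def pruning(node):
    """Partial likelihoods L[s] = P(tip data below node | node state s)."""
    if node in tips:
        return [1.0 if s == tips[node] else 0.0 for s in STATES]
    partial = [1.0] * 4
    for child, t in children[node]:
        child_l = pruning(child)
        for s in STATES:
            partial[s] *= sum(jc_prob(s, c, t) * child_l[c] for c in STATES)
    return partial

# Site likelihood: root states weighted by the uniform stationary distribution.
site_lik = sum(0.25 * l for l in pruning("root"))

# Brute force: enumerate all assignments of the two internal nodes.
brute = 0.0
for r, u in product(STATES, STATES):
    p = 0.25  # eta(a_r)
    p *= jc_prob(r, tips["A"], 0.1) * jc_prob(r, u, 0.2)
    p *= jc_prob(u, tips["B"], 0.3) * jc_prob(u, tips["C"], 0.15)
    brute += p
assert abs(site_lik - brute) < 1e-12
```

The pruning recursion visits each edge once, so the cost is linear in the number of nodes rather than exponential in the number of internal states.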
VBPI finds the best approximation to $p(\tau, q|Y)$ from the family of products of $Q_\phi(\tau)$ and $Q_\psi(q|\tau)$ by maximizing the following multi-sample lower bound:

$$L^K(\phi, \psi) = \mathbb{E}_{Q_{\phi,\psi}(\tau^{1:K}, q^{1:K})} \log\left(\frac{1}{K}\sum_{i=1}^K \frac{p(Y|\tau^i, q^i)p(\tau^i, q^i)}{Q_\phi(\tau^i)Q_\psi(q^i|\tau^i)}\right) \le \log p(Y),$$

where $Q_{\phi,\psi}(\tau^{1:K}, q^{1:K}) = \prod_{i=1}^K Q_\phi(\tau^i)Q_\psi(q^i|\tau^i)$. To properly parameterize the variational distributions, a support of the conditional probability tables (CPTs) is often acquired from a sample of tree topologies via fast heuristic bootstrap methods (Minh et al., 2013; Zhang & Matsen IV, 2019). The branch length approximation $Q_\psi(q|\tau)$ is taken to be the diagonal Lognormal distribution

$$Q_\psi(q|\tau) = \prod_{e\in E(\tau)} p^{\mathrm{Lognormal}}\left(q_e \mid \mu(e, \tau), \sigma(e, \tau)\right),$$

where $\mu(e, \tau), \sigma(e, \tau)$ are amortized over the tree topology space via shared local structures (i.e., splits and primary subsplit pairs (PSPs)), which are available from the support of CPTs. More details about structured amortization, VBPI and SBNs can be found in section 3.2.2 and Appendix A.

(Figure 1, right: the tree topology with embedded node features is fed into a GNN model for the more sophisticated tree structure representation learning required by downstream tasks.)

Graph neural networks. Let $G = (V, E)$ denote a graph with node feature vectors $X_v$ for node $v \in V$, and let $N(v)$ denote the set of nodes adjacent to $v$. GNNs iteratively update the representation of a node by running a message passing (MP) scheme for $T$ time steps. During each MP time step, the representation vector of each node is updated based on the aggregated messages from its neighbors as follows:

$$h_v^{(t+1)} = \mathrm{UPDATE}^{(t)}\left(h_v^{(t)}, m_v^{(t+1)}\right), \quad m_v^{(t+1)} = \mathrm{AGG}^{(t)}\left(\left\{h_u^{(t)} : u \in N(v)\right\}\right),$$

where $h_v^{(t)}$ is the feature vector of node $v$ at time step $t$, with initialization $h_v^{(0)} = X_v$, $\mathrm{UPDATE}^{(t)}$ is the update function, and $\mathrm{AGG}^{(t)}$ is the aggregation function.
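As a sanity check on the multi-sample lower bound, the sketch below enumerates all $K$-tuples exactly for a toy two-topology model (branch lengths omitted; the weights $w(\tau) = p(Y|\tau)p(\tau)$ are made up for illustration), confirming $L^1 \le L^2 \le \log p(Y)$.

```python
import math
from itertools import product

# Toy discrete model: unnormalized weights w(tau) = p(Y|tau)p(tau), and a
# variational distribution Q over the two topologies. Values are illustrative.
w = {"t1": 0.03, "t2": 0.01}
Q = {"t1": 0.5, "t2": 0.5}
log_pY = math.log(sum(w.values()))  # log marginal likelihood

def multi_sample_bound(K):
    """Exact L^K: enumerate every K-tuple of topologies drawn i.i.d. from Q."""
    bound = 0.0
    for taus in product(w, repeat=K):
        prob = math.prod(Q[t] for t in taus)
        bound += prob * math.log(sum(w[t] / Q[t] for t in taus) / K)
    return bound

L1, L2 = multi_sample_bound(1), multi_sample_bound(2)
assert L1 <= L2 <= log_pY  # larger K tightens the bound toward log p(Y)
```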
A number of powerful GNNs with different implementations of the update and aggregation functions have been proposed (Kipf & Welling, 2017; Hamilton et al., 2017; Li et al., 2016; Veličković et al., 2018; Xu et al., 2019; Wang et al., 2019). In addition to local node-level features, GNNs can also provide features for the entire graph. To learn these global features, an additional READOUT function is often introduced to aggregate the node features from the final iteration:

$$h_G = \mathrm{READOUT}\left(\left\{h_v^{(T)} : v \in V\right\}\right).$$

READOUT can be any function that is permutation invariant to the node features.
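A minimal instance of the MP scheme and READOUT above, on a three-node path graph with scalar features. The particular choices here (mean aggregation, a convex-combination update, sum readout) are ours for illustration, not a specific GNN from the literature.

```python
# Path graph 0 - 1 - 2 with scalar node features.
neighbors = {0: [1], 1: [0, 2], 2: [1]}
h = {0: 1.0, 1: 2.0, 2: 4.0}

def mp_step(h):
    """One message passing step: AGG = neighbor mean, UPDATE = 0.5*self + 0.5*msg."""
    new = {}
    for v, nbrs in neighbors.items():
        m = sum(h[u] for u in nbrs) / len(nbrs)  # AGG over N(v)
        new[v] = 0.5 * h[v] + 0.5 * m            # UPDATE
    return new

h1 = mp_step(h)            # node-level features after one step
h_G = sum(h1.values())     # READOUT: permutation-invariant sum -> graph feature
```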

3. PROPOSED METHOD

In this section, we propose a general approach that automatically learns topological features directly from phylogenetic trees. We first introduce a simple embedding method that provides raw features for the nodes of phylogenetic trees, together with an efficient linear time algorithm for obtaining these raw features and a discussion on some of their theoretical properties regarding tree topology representation. We then describe how these raw features can be adapted to learn efficient representations of certain structures of trees (e.g., edges) for downstream tasks.

3.1. INTERIOR NODE EMBEDDING

Learning tree structure features directly from tree topologies often requires raw node/edge features, as typically assumed in most GNN models. Unfortunately, this is not the case for phylogenetic models. Although we can use one-hot encoding for the tip nodes according to their corresponding species (taxa names only, not the sequences), the interior nodes still lack original features. The first step of tree structure representation learning for phylogenetic models, therefore, is to properly impute those missing features for the interior nodes. Following previous studies (Zhu & Ghahramani, 2002; Rossi et al., 2021), we make a common assumption that the node features change smoothly across the tree topology (i.e., the features of every node are similar to those of its neighbors). A widely used criterion of smoothness for functions defined on the nodes of a graph is the Dirichlet energy. Given a tree topology $\tau = (V, E)$ and a function $f: V \to \mathbb{R}^d$, the Dirichlet energy is defined as

$$\ell(f, \tau) = \sum_{(u,v)\in E} \|f(u) - f(v)\|^2.$$

Let $V = V_b \cup V_o$, where $V_b$ is the set of tip nodes with given features $X_b = \{x_v \mid v \in V_b\}$ and $V_o$ is the set of interior nodes with features $X_o = \{x_u \mid u \in V_o\}$ to be determined. The interior node features are obtained by minimizing the Dirichlet energy:

$$\hat{X}_o = \mathop{\arg\min}_{X_o} \ell(X_o, X_b, \tau) = \mathop{\arg\min}_{X_o} \sum_{(u,v)\in E} \|x_u - x_v\|^2. \qquad (2)$$
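The Dirichlet energy in equation 2 is straightforward to evaluate; the sketch below computes it on a toy four-tip tree with one-hot-like tip features and a hand-picked interior assignment (node names and feature values are illustrative only).

```python
def dirichlet_energy(edges, x):
    """Sum over edges of the squared distance between endpoint features."""
    return sum(
        sum((xu - xv) ** 2 for xu, xv in zip(x[u], x[v]))
        for u, v in edges
    )

# Toy tree ((A,B),(C,D)): interior nodes p, q joined through the root r.
edges = [("r", "p"), ("r", "q"), ("p", "A"), ("p", "B"), ("q", "C"), ("q", "D")]
x = {"A": [1, 0], "B": [0, 1], "C": [1, 0], "D": [0, 1],
     "p": [0.5, 0.5], "q": [0.5, 0.5], "r": [0.5, 0.5]}
e = dirichlet_energy(edges, x)  # 4 tip edges each contribute 0.25 + 0.25
```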

3.1.1. A LINEAR TIME TWO-PASS ALGORITHM

Note that the above Dirichlet energy is convex in $X_o$; its minimizer can therefore be obtained by solving the optimality condition

$$\frac{\partial \ell(X_o, X_b, \tau)}{\partial X_o}(\hat{X}_o) = 0. \qquad (3)$$

It turns out that equation 3 has a closed-form solution based on matrix inversion. However, as matrix inversion scales cubically in general, it is infeasible for graphs with many nodes. Fortunately, by leveraging the hierarchical structure of phylogenetic trees, we can design a more efficient linear time algorithm for the solution of equation 3 as follows. We first rewrite equation 3 as a system of linear equations

$$\sum_{v \in N(u)} (\hat{x}_u - \hat{x}_v) = 0, \ \forall u \in V_o; \qquad \hat{x}_v = x_v, \ \forall v \in V_b, \qquad (4)$$

where $N(u)$ is the set of neighbors of node $u$. Given a topological ordering induced by the tree, we can obtain the solution within a two-pass sweep through the tree topology, similar to the Thomas algorithm for solving tridiagonal systems of linear equations (Thomas, 1949). In the first pass, we traverse the tree in a postorder fashion and express the node features as linear functions of those of their parents,

$$\hat{x}_u = c_u \hat{x}_{\pi_u} + d_u, \qquad (5)$$

for all nodes except the root node, where $\pi_u$ denotes the parent node of $u$. More specifically, we first initialize $c_u = 0, d_u = x_u$ for all tip nodes $u \in V_b$. For all interior nodes except the root node, we compute $c_u, d_u$ recursively as follows (see a detailed derivation in Appendix B):

$$c_u = \frac{1}{|N(u)| - \sum_{v \in \mathrm{ch}(u)} c_v}, \qquad d_u = \frac{\sum_{v \in \mathrm{ch}(u)} d_v}{|N(u)| - \sum_{v \in \mathrm{ch}(u)} c_v}, \qquad (6)$$

where $\mathrm{ch}(u)$ denotes the set of child nodes of $u$. In the second pass, we traverse the tree in a preorder fashion and compute the solution by back substitution. Concretely, at the root node $r$, given equation 5 for all the child nodes from the first pass, we can compute the node feature directly from equation 4 as

$$\hat{x}_r = \frac{\sum_{v \in \mathrm{ch}(r)} d_v}{|N(r)| - \sum_{v \in \mathrm{ch}(r)} c_v}. \qquad (7)$$

For all other interior nodes, the node features can be obtained via equation 5 by substituting in the already-computed features of their parent nodes.
We summarize our two-pass algorithm in Algorithm 1. Moreover, the algorithm is numerically stable due to the following lemma (proof in Appendix C).

Lemma 1. Let $\lambda = \min_{u \in V_o \setminus \{r\}} |N(u)|$. For all interior nodes $u \in V_o \setminus \{r\}$, $0 \le c_u \le \frac{1}{\lambda - 1}$.

Besides bifurcating phylogenetic trees, the above two-pass algorithm can be easily adapted to interior node embedding for general tree-shaped graphs with given tip node features.

Algorithm 1: A Two-pass Algorithm for Interior Node Embedding
1: Input: tree topology $\tau = (V, E)$, where $V = V_b \cup V_o$; features for the tip nodes $\{x_u \mid u \in V_b\}$.
2: Initialize $c_u = 0, d_u = x_u, \forall u \in V_b$.
3: Traverse the tree topology in a postorder fashion. For any interior node $u$ that is not the root node, compute $c_u$ and $d_u$ as in equation 6.
4: Traverse the tree topology in a preorder fashion. For the root node $r$, compute the node feature as in equation 7. For any other interior node $u$, compute the node feature as $\hat{x}_u = c_u \hat{x}_{\pi_u} + d_u$.
5: return $\{\hat{x}_u \mid u \in V_o\}$.
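A compact Python sketch of Algorithm 1 (the dictionary-based tree representation is our own choice). On a toy four-tip tree with one-hot tip features, the output can be checked against the stationarity conditions of equation 4, and it also exhibits the convex-combination property of Theorem 1.

```python
def interior_embed(children, root, tip_feats, dim):
    """Two-pass interior node embedding (Dirichlet energy minimization).

    children: interior node -> list of child nodes.
    tip_feats: tip node -> feature vector of length dim.
    """
    c = {v: 0.0 for v in tip_feats}               # x_u = c_u * x_parent + d_u
    d = {v: list(f) for v, f in tip_feats.items()}

    def postorder(u):                             # first pass: compute c_u, d_u
        for v in children.get(u, []):
            postorder(v)
        if u in tip_feats or u == root:
            return
        denom = (len(children[u]) + 1) - sum(c[v] for v in children[u])
        c[u] = 1.0 / denom
        d[u] = [sum(d[v][k] for v in children[u]) / denom for k in range(dim)]

    postorder(root)

    x = {v: list(f) for v, f in tip_feats.items()}
    denom = len(children[root]) - sum(c[v] for v in children[root])
    x[root] = [sum(d[v][k] for v in children[root]) / denom for k in range(dim)]

    def preorder(u):                              # second pass: back substitution
        for v in children.get(u, []):
            if v not in tip_feats:
                x[v] = [c[v] * x[u][k] + d[v][k] for k in range(dim)]
                preorder(v)

    preorder(root)
    return x

# Toy rooted tree ((A,B),(C,D)) with one-hot tip features.
children = {"r": ["p", "q"], "p": ["A", "B"], "q": ["C", "D"]}
tips = {"A": [1, 0, 0, 0], "B": [0, 1, 0, 0],
        "C": [0, 0, 1, 0], "D": [0, 0, 0, 1]}
feats = interior_embed(children, "r", tips, 4)
```

Each node is visited exactly twice, so the cost is linear in the number of nodes.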

3.1.2. TREE TOPOLOGY REPRESENTATION POWER

In this section, we discuss some theoretical properties regarding the tree topology representation power of the node features introduced above. We start with a useful lemma that elucidates an important property of the solution to the linear system 4, akin to the extremum principles for solutions of elliptic equations.

Lemma 2 (Extremum Principle). Let $\{\hat{x}_u \in \mathbb{R}^d \mid u \in V\}$ be a set of $d$-dimensional node features that satisfies equation 4. For $1 \le n \le d$, let $X[n] = \{\hat{x}_u[n] \mid u \in V\}$ be the set of the $n$-th components of the node features. Then, $\forall 1 \le n \le d$, we have: (i) the extremum values (i.e., maximum and minimum) of $X[n]$ are achieved at some tip nodes; (ii) if an extremum value is achieved at an interior node, then $X[n]$ has only one member, or in other words, $\hat{x}_u[n]$ is the same $\forall u \in V$.

Theorem 1. Let $N$ be the number of tip nodes and let $\{\hat{x}_u \in \mathbb{R}^N \mid u \in V\}$ be the solution to the linear system 4 with one-hot encoding for the tip nodes. Then, $\forall u \in V_o$, we have (i) $0 < \hat{x}_u[n] < 1, \forall 1 \le n \le N$, and (ii) $\sum_{n=1}^N \hat{x}_u[n] = 1$.

The complete proofs of Lemma 2 and Theorem 1 are provided in Appendix C. When the tip node features are linearly independent, a similar proposition holds if we instead consider the coefficients of the linear combination of the tip node features that forms each interior node feature.

Corollary 1. Suppose the tip node features are linearly independent. Then the interior node features obtained from the solution to the linear system 4 all lie in the interior of the convex hull of the tip node features.

The proof is provided in Appendix C. The following lemma reveals a key property, in the embedded feature space, of the interior nodes adjacent to the boundary of the tree topology.

Lemma 3. Let $\{\hat{x}_u \mid u \in V\}$ be the solution to the linear system 4 with linearly independent tip node features, and let $\{\hat{x}_u = \sum_{v \in V_b} a^v_u x_v \mid u \in V_o\}$ be the convex combination representations of the interior node features.
For any tip node $v \in V_b$, we have $u^* = \arg\max_{u \in V_o} a^v_u \Leftrightarrow u^* \in N(v)$.

Theorem 2 (Identifiability). Let $\hat{X}_o = \{\hat{x}_u \mid u \in V_o\}$ and $\hat{Z}_o = \{\hat{z}_u \mid u \in V_o\}$ be the sets of interior node features that minimize the Dirichlet energy for phylogenetic tree topologies $\tau_x$ and $\tau_z$ respectively, given the same linearly independent tip node features. If $\hat{X}_o = \hat{Z}_o$, then $\tau_x = \tau_z$.

The proofs of Lemma 3 and Theorem 2 are provided in Appendix C. By Theorem 2, the proposed node embeddings are complete representations of phylogenetic tree topologies with no information loss.

3.2. STRUCTURAL REPRESENTATION LEARNING VIA GRAPH NEURAL NETWORKS

Using the node embeddings introduced in section 3.1 as raw features, we now show how to learn more sophisticated representations of tree structures for different phylogenetic inference tasks via GNNs. Given a tree topology $\tau$, let $\{h_v^{(T)} : v \in V\}$ be the output features after the final iteration of a GNN. We feed these output features into a multi-layer perceptron (MLP) to get a set of learnable features for each node,

$$h_v = \mathrm{MLP}^{(0)}\left(h_v^{(T)}\right), \quad \forall v \in V,$$

before adapting them to different downstream tasks, as demonstrated in the following examples.

3.2.1. ENERGY BASED MODELS FOR TREE PROBABILITY ESTIMATION

Our first example concerns graph-level representation learning of phylogenetic tree topologies. Let $\mathcal{T}$ denote the entire tree topology space. Given learnable node features of tree topologies, one can use a permutation invariant function $g$ to obtain graph-level features and hence create an energy function $F_\phi: \mathcal{T} \to \mathbb{R}$ that assigns each tree topology a scalar value as follows:

$$F_\phi(\tau) = \mathrm{MLP}^{(1)}(h_G), \quad h_G = g\left(\{h_v : v \in V\}\right),$$

where $g \circ \mathrm{MLP}^{(0)}$ can be viewed as a READOUT function as in section 2. This allows us to construct energy based models (EBMs) for tree probability estimation:

$$q_\phi(\tau) = \frac{\exp(-F_\phi(\tau))}{Z(\phi)}, \quad Z(\phi) = \sum_{\tau \in \mathcal{T}} \exp(-F_\phi(\tau)).$$

As $Z(\phi)$ is usually intractable, we can employ noise contrastive estimation (NCE) (Gutmann & Hyvärinen, 2010) to train these energy based models. Let $p_n$ be some noise distribution that has a tractable density and allows efficient sampling, and let $D_\phi(\tau) = \log q_\phi(\tau) - \log p_n(\tau)$. We can train $D_\phi$ to minimize the following objective function (NCE loss):

$$J(\phi) = -\mathbb{E}_{\tau \sim p_{\mathrm{data}}(\tau)} \log\left(S(D_\phi(\tau))\right) - \mathbb{E}_{\tau \sim p_n(\tau)} \log\left(1 - S(D_\phi(\tau))\right),$$

where $S(x) = \frac{1}{1 + \exp(-x)}$ is the sigmoid function. It is easy to verify that the minimum is achieved at $D_{\phi^*}(\tau) = \log p_{\mathrm{data}}(\tau) - \log p_n(\tau)$. Therefore, $q_{\phi^*}(\tau) = p_{\mathrm{data}}(\tau) = p_n(\tau) \exp(D_{\phi^*}(\tau))$.
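The NCE optimum stated above can be verified numerically on a toy topology space; the three "topologies" and their probabilities below are made up for illustration. At the optimal discriminator, the loss also matches the identity $J^* = -2\,\mathrm{JSD}(p_{\mathrm{data}} \| p_n) + 2\log 2$ used in section 4.1.

```python
import math

def nce_loss(p_data, p_noise, D):
    """J = -E_data[log S(D)] - E_noise[log(1 - S(D))], S = sigmoid."""
    S = lambda z: 1.0 / (1.0 + math.exp(-z))
    j = 0.0
    for tau in p_data:
        j -= p_data[tau] * math.log(S(D[tau]))
        j -= p_noise[tau] * math.log(1.0 - S(D[tau]))
    return j

# Toy tree-topology space with three topologies and uniform noise.
p_data = {"t1": 0.6, "t2": 0.3, "t3": 0.1}
p_noise = {t: 1 / 3 for t in p_data}

# Optimal discriminator: D*(tau) = log p_data(tau) - log p_noise(tau).
D_star = {t: math.log(p_data[t]) - math.log(p_noise[t]) for t in p_data}
J_star = nce_loss(p_data, p_noise, D_star)

# Check J* = -2 JSD(p_data || p_noise) + 2 log 2.
m = {t: 0.5 * (p_data[t] + p_noise[t]) for t in p_data}
jsd = 0.5 * sum(p_data[t] * math.log(p_data[t] / m[t]) for t in p_data) \
    + 0.5 * sum(p_noise[t] * math.log(p_noise[t] / m[t]) for t in p_data)
assert abs(J_star - (-2 * jsd + 2 * math.log(2))) < 1e-12

# Any perturbed discriminator incurs a strictly larger loss.
D_bad = {t: D_star[t] + (0.5 if t == "t1" else -0.5) for t in p_data}
assert nce_loss(p_data, p_noise, D_bad) > J_star
```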

3.2.2. BRANCH LENGTH PARAMETERIZATION FOR VBPI

The branch length parameterization in VBPI has so far relied on hand-engineered features (i.e., splits and PSPs) for the edges of tree topologies. Let $S_r$ denote the set of splits and $S_{\mathrm{psp}}$ denote the set of PSPs. The simple split-based parameterization assigns parameters $\psi^\mu, \psi^\sigma$ to the splits in $S_r$. The mean and standard deviation for each edge $e$ of $\tau$ are then given by the associated parameters of the corresponding split $e/\tau$ as follows:

$$\mu(e, \tau) = \psi^\mu_{e/\tau}, \quad \sigma(e, \tau) = \psi^\sigma_{e/\tau}. \qquad (8)$$

The more flexible PSP parameterization assigns additional parameters to the PSPs in $S_{\mathrm{psp}}$ and adds the associated parameters of the corresponding PSPs $e/\!/\tau$ to equation 8 to refine the mean and standard deviation parameterization:

$$\mu(e, \tau) = \psi^\mu_{e/\tau} + \sum_{s \in e/\!/\tau} \psi^\mu_s, \quad \sigma(e, \tau) = \psi^\sigma_{e/\tau} + \sum_{s \in e/\!/\tau} \psi^\sigma_s. \qquad (9)$$

Although these heuristic features prove effective, they often require substantial design effort and a sample of tree topologies for feature collection, and they cannot adapt during training, which makes amortized inference over different tree topologies difficult. Based on the learnable node features, we can design a more flexible branch length parameterization that is capable of distilling more effective structural information of tree topologies for variational approximations. For each edge $e = (u, v)$ of $\tau$, similarly as in section 3.2.1, one can use a permutation invariant function $f$ to obtain edge-level features and transform them into the mean and standard deviation parameters as follows:

$$\mu(e, \tau) = \mathrm{MLP}^\mu(h_e), \quad \sigma(e, \tau) = \mathrm{MLP}^\sigma(h_e), \quad h_e = f(\{h_u, h_v\}). \qquad (10)$$

Compared to the heuristic feature based parameterizations in equations 8 and 9, the learnable topological feature based parameterization in equation 10 allows a much richer design for the branch length distributions across different tree topologies and does not require pre-sampled tree topologies for feature collection.
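A toy sketch of the edge-level parameterization in equation 10, using elementwise max as the permutation invariant $f$ (matching the aggregation used in our experiments); the linear maps standing in for $\mathrm{MLP}^\mu$ and $\mathrm{MLP}^\sigma$ and all numeric values are made up for illustration.

```python
def edge_feature(h_u, h_v):
    """Permutation invariant f: elementwise max of the two endpoint features."""
    return [max(a, b) for a, b in zip(h_u, h_v)]

def linear(w, b, h):
    """Toy one-layer stand-in for MLP_mu / MLP_sigma."""
    return sum(wi * hi for wi, hi in zip(w, h)) + b

h_u, h_v = [0.2, 0.7, 0.1], [0.5, 0.3, 0.4]   # learned node features (made up)
h_e = edge_feature(h_u, h_v)                   # edge-level feature
mu = linear([1.0, -0.5, 0.2], 0.1, h_e)        # mean of the Lognormal
sigma = linear([0.3, 0.3, 0.3], 0.0, h_e)      # std of the Lognormal
```

Because $f$ is symmetric in its arguments, the parameterization does not depend on the (arbitrary) ordering of an edge's endpoints.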

4. EXPERIMENTS

In this section, we test the effectiveness and efficiency of learnable topological features for phylogenetic inference on the two aforementioned benchmark tasks: tree probability estimation via energy based models and branch length parameterization for VBPI. Following Zhang & Matsen IV (2019), in VBPI we used the simplest SBN for the tree topology variational distribution, and the CPT supports were estimated from ultrafast maximum likelihood phylogenetic bootstrap trees using UFBoot (Minh et al., 2013). The code is available at https://github.com/zcrabbit/vbpi-gnn.

Experimental setup. We evaluate five commonly used GNN variants with the following convolution operators: graph convolutional networks (GCN), graph isomorphism operator (GIN), GraphSAGE operator (SAGE), gated graph convolution operator (GGNN), and edge convolution operator (EDGE). See more details about these convolution operators in Appendix F. In addition to the above GNN variants, we also considered a simpler model that skips all GNN iterations (i.e., $T = 0$), referred to as MLP in the sequel. All GNN variants have 2 GNN layers (including the input layer), and all involved MLPs have 2 layers. We used summation as our permutation invariant aggregation function for graph-level features and maximization for edge-level features. All models were implemented in PyTorch (Paszke et al., 2019) with the Adam optimizer (Kingma & Ba, 2015). We designed our experiments with the goals of (i) verifying the effectiveness of GNN-based EBMs for tree topology estimation and (ii) verifying the improvement of GNN-based branch length parameterization for VBPI over the baseline approaches (i.e., split and PSP based parameterizations) and investigating how helpful the learnable topological features are for reducing the amortization gaps.

4.1. SIMULATED DATA TREE PROBABILITY ESTIMATION

We first investigated the representative power of learnable topological features for approximating distributions on phylogenetic trees using energy based models (EBMs), and conducted experiments on a simulated data set. We used the space of unrooted phylogenetic trees with 8 leaves, which contains 10395 unique trees in total. As in Zhang & Matsen IV (2019), we generated a target distribution $p_0(\tau)$ by drawing a sample from the symmetric Dirichlet distribution $\mathrm{Dir}(\beta \mathbf{1})$ of order 10395 with a pre-selected arbitrary order of trees. The concentration parameter $\beta$ controls the diffuseness of the target distribution and was set to 0.008 to provide enough information for inference while keeping the target adequately diffuse. As mentioned in section 3.2.1, we used noise contrastive estimation (NCE) to train our EBMs, with the noise distribution $p_n(\tau)$ set to the uniform distribution. Results were collected after 200,000 parameter updates. Note that the minimum NCE loss in this case is $J^* = -2\,\mathrm{JSD}(p_0(\tau) \| p_n(\tau)) + 2\log 2$, where $\mathrm{JSD}(\cdot \| \cdot)$ is the Jensen-Shannon divergence. Figure 2 shows the empirical performance of different methods. From the left plot, we see that the NCE losses converge rapidly and the gaps between the NCE losses for the GNN variants and the best NCE loss $J^*$ (dashed red line) are close to zero, demonstrating the representative power of learnable topological features for phylogenetic tree probability estimation. The evolution of the KL divergences (middle plot) is consistent with the NCE losses. Compared to MLP, all GNN variants perform better, indicating that the extra flexibility provided by GNN iterations is crucial for tree probability estimation that benefits from more informative graph-level features. Although the raw features from interior node embedding contain all the information of phylogenetic tree topologies, we see that distilling effective structural information from them is still challenging.
This makes GNN models that are by design more capable of learning geometric representations a favorable choice. The right plot compares the probability mass approximations provided by EBMs using MLP and GGNN (which performs the best among all GNN variants), to the ground truth p 0 (τ ). We see that EBMs using GGNN consistently provide accurate approximations for trees across a wide range of probabilities. On the other hand, estimates provided by those using MLP are often of large bias, except for a few trees with high probabilities.

4.2. REAL DATA VARIATIONAL BAYESIAN PHYLOGENETIC INFERENCE

The second task we considered is VBPI, where we compared learnable topological feature based branch length parameterizations to the heuristic feature based parameterizations (denoted as Split and PSP respectively) proposed in the original VBPI approach (Zhang & Matsen IV, 2019). All methods were evaluated on 8 real datasets that are commonly used to benchmark Bayesian phylogenetic inference methods (Hedges et al., 1990; Garey et al., 1996; Yang & Yoder, 2003; Henk et al., 2003; Lakner et al., 2008; Zhang & Blackwell, 2001; Yoder & Yang, 2004; Rossman et al., 2001; Höhna & Drummond, 2012; Larget, 2013; Whidden & Matsen IV, 2015). These datasets, which we call DS1-8, consist of sequences from 27 to 64 eukaryote species with 378 to 2520 site observations. We concentrate on the most challenging part of Bayesian phylogenetics: joint learning of the tree topologies and the branch lengths. We assume a uniform prior on the tree topology, an i.i.d. exponential prior ($\mathrm{Exp}(10)$) on the branch lengths, and the simple Jukes & Cantor (1969) substitution model. We gathered the support of CPTs from 10 replicates of 10000 ultrafast maximum likelihood bootstrap trees (Minh et al., 2013). We set $K = 10$ for the multi-sample lower bound and used an annealing schedule $\lambda_n = \min(1, 0.001 + n/100000)$, going from 0.001 to 1 after 100000 iterations. The Monte Carlo gradient estimates for the tree topology parameters and the branch length parameters were obtained via VIMCO (Mnih & Rezende, 2016) and the reparameterization trick (Kingma & Welling, 2014), respectively. Results were collected after 400,000 parameter updates. Table 1 shows the estimates of the evidence lower bound (ELBO) and the marginal likelihood using different branch length parameterizations on the 8 benchmark datasets, including the results for the stepping-stone (SS) method (Xie et al., 2011), which is one of the state-of-the-art sampling based methods for marginal likelihood estimation.
For each data set, a better approximation would lead to a smaller variance of the marginal likelihood estimates. We see that, solely using the raw features, MLP-based parameterization already outperformed the Split and PSP baselines by providing tighter lower bounds. With more expressive representations of local structures enabled by GNN iterations, GNN-based parameterizations further improved upon MLP-based methods, indicating the importance of harnessing local topological information for flexible branch length distributions. Moreover, when used as importance distributions for marginal likelihood estimation via importance sampling, MLP and the GNN variants provide more stable estimates (less variance) than Split and PSP respectively. All variational approaches compare favorably to SS and require far fewer samples. The left plot in Figure 3 shows the evidence lower bounds as a function of the number of parameter updates on DS1. Although neural network based parameterization adds to the complexity of training in VI, we see that by the time Split and PSP converge, MLP and EDGE achieve comparable (if not better) lower bounds and quickly surpass these baselines as the number of iterations increases. As diagonal Lognormal branch length distributions were used for all parameterization methods, how these variational distributions are amortized over tree topologies under different parameterizations is therefore crucial for the overall approximation performance. To better understand this effect of amortized inference, we further investigated the amortization gaps of different methods on individual trees in the 95% credible set of DS1, as in Zhang (2020). The middle and right plots in Figure 3 show the amortization gaps of different parameterization methods on each tree topology $\tau$.
We see the amortization gaps of MLP and EDGE are considerably smaller than those of Split and PSP respectively, showing the efficiency of learnable topological features for amortized branch length distributions. Again, incorporating more local topological information is beneficial, as evidenced by the significant improvement of EDGE over MLP. More results about the amortization gaps can be found in Table 2 in the appendix.

5. CONCLUSION

We presented a novel approach for phylogenetic inference based on learnable topological features. By combining the raw node features that minimize the Dirichlet energy with modern GNN variants, our learnable topological features can provide efficient structural information without requiring domain expertise. In experiments, we demonstrated the effectiveness of our approach for tree probability estimation on simulated data and showed that our method consistently outperforms the baseline approaches for VBPI on a benchmark of real data sets. Future work would investigate more sophisticated GNNs for phylogenetic trees, and applications to other phylogenetic inference tasks where efficiently leveraging structural information of tree topologies is of great importance.

Figure 4: Left: some rooted phylogenetic tree examples. Middle: the corresponding SBN assignments. For ease of illustration, subsplit $(W, Z)$ is represented as $WZ$ in the graph. The dashed gray subgraphs represent fake splitting processes where splits are deterministically assigned, and are used purely to complement the networks such that the overall network has a fixed structure. Right: the SBN for these examples. This figure is adapted from Zhang & Matsen IV (2019).

A SUBSPLIT BAYESIAN NETWORKS

Subsplit Bayesian networks (SBNs) were introduced by Zhang & Matsen IV (2018) for density estimation on tree topologies and were later used to provide a family of variational distributions for Bayesian phylogenetic inference (Zhang & Matsen IV, 2019). Given a leaf set $\mathcal{X}$ of size $N$, one can define a subsplit Bayesian network $B_{\mathcal{X}}$ to be a Bayesian network whose nodes take on subsplit or singleton clade values that represent the local topological structures of trees (Figure 4). This formulation allows us to represent rooted tree topologies as SBN assignments. More specifically, one can follow the splitting process (see the solid dark subgraphs in Figure 4, middle) of the tree and assign the subsplits to the corresponding nodes along the way to get a unique subsplit decomposition of the tree topology. Given the subsplit decomposition of a rooted tree $\tau = \{s_1, s_2, \ldots\}$, where $s_1$ is the root subsplit, the SBN-induced tree probability of $\tau$ is

$$p_{\mathrm{sbn}}(T = \tau) = p(S_1 = s_1) \prod_{i > 1} p(S_i = s_i \mid S_{\pi_i} = s_{\pi_i}),$$

where $S_i$ denotes the subsplit- or singleton-clade-valued random variable at node $i$ and $\pi_i$ is the index set of the parents of $S_i$. As Bayesian networks, SBN-induced distributions are all naturally normalized. We can also adjust the structures of SBNs for a wide range of expressive distributions, as long as they remain valid directed acyclic graphs (DAGs). In practice, the simplest SBN (the one with a full and complete binary tree structure, as shown in Figure 4) is often found to be good enough. The SBN framework also generalizes to unrooted trees, which are the most common type of phylogenetic trees. The key is to view unrooted trees as rooted trees with missing roots. Marginalizing out the unobserved root nodes leads to the SBN probability estimates for unrooted trees:

$$p_{\mathrm{sbn}}(T^u = \tau) = \sum_{s_1 \sim \tau} p(S_1 = s_1) \prod_{i > 1} p(S_i = s_i \mid S_{\pi_i} = s_{\pi_i}),$$

where $s_1 \sim \tau$ denotes all root subsplits that are compatible with $\tau$ (i.e., root subsplits of the edges of $\tau$).
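The rooted SBN factorization above can be sketched with a small dictionary-based CPT; the subsplit names (written as clade pairs like "AB|C") and all probability values below are illustrative only, not learned from data.

```python
# Illustrative CPTs for a 4-taxon example. Root subsplits of the clade ABCD,
# then conditional subsplits of the resulting child clades.
root_cpt = {"ABC|D": 0.7, "AB|CD": 0.3}
cond_cpt = {
    ("ABC|D", "AB|C"): 0.6,   # how clade ABC splits, given root subsplit ABC|D
    ("ABC|D", "A|BC"): 0.4,
    ("AB|C", "A|B"): 1.0,     # singleton-clade splits are deterministic
    ("A|BC", "B|C"): 1.0,
}

def sbn_prob(root_subsplit, parent_child_pairs):
    """p(T = tau) = p(s1) * prod over parent-child subsplit pairs."""
    p = root_cpt[root_subsplit]
    for parent, child in parent_child_pairs:
        p *= cond_cpt[(parent, child)]
    return p

# Rooted topology whose subsplit decomposition is {ABC|D, AB|C, A|B}:
p = sbn_prob("ABC|D", [("ABC|D", "AB|C"), ("AB|C", "A|B")])
```

Because each CPT row is a proper conditional distribution, the induced distribution over rooted topologies is normalized by construction.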
We can evaluate the SBN probabilities of tree topologies efficiently through a two-pass algorithm (Zhang & Matsen IV, 2018). As Bayesian networks, SBNs also allow fast sampling procedures (i.e., ancestral sampling). These properties make SBNs a natural choice for variational inference. Parameterizing SBNs in VBPI requires a sufficiently large subsplit support for the CPTs (i.e., where the associated conditional probabilities are allowed to take nonzero values) that covers favorable parent-child subsplit pairs from trees with high posterior probabilities. In practice, a simple bootstrap-based approach has been found effective for providing such a support (Zhang & Matsen IV, 2019). Let $\mathbb{S}_r$ denote the set of root subsplits (e.g., the splits) in the support and $\mathbb{S}_{\mathrm{ch}|\mathrm{pa}}$ denote the set of parent-child subsplit pairs in the support. The CPTs can then be defined via the softmax function as follows:
$$p(S_1 = s_1) = \frac{\exp(\phi_{s_1})}{\sum_{s_r \in \mathbb{S}_r} \exp(\phi_{s_r})}, \qquad p(S_i = s \mid S_{\pi_i} = t) = \frac{\exp(\phi_{s|t})}{\sum_{s' \in \mathbb{S}_{\cdot|t}} \exp(\phi_{s'|t})}.$$
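As a concrete illustration, the softmax parameterization of the CPTs and the product form of the SBN probability can be sketched in a few lines of Python. The subsplit names, support, and score values below are hypothetical toy choices for a 4-taxon example, not taken from the paper:

```python
import math

def softmax(scores):
    """Normalize a dict of unconstrained scores phi into probabilities."""
    m = max(scores.values())
    z = sum(math.exp(s - m) for s in scores.values())
    return {k: math.exp(s - m) / z for k, s in scores.items()}

# Hypothetical support for a leaf set {A,B,C,D}: root subsplit scores phi_{s_1}
# and child-given-parent scores phi_{s|t} (here each clade splits only one way).
phi_root = {"AB|CD": 0.4, "A|BCD": -0.1, "ABC|D": 0.0}
phi_child = {
    ("AB|CD", "AB"): {"A|B": 0.0},   # clade AB can only split into A|B in this toy support
    ("AB|CD", "CD"): {"C|D": 0.0},
}

def sbn_prob(root_subsplit, child_assignments):
    """p_sbn(tau) = p(S_1 = s_1) * prod_{i>1} p(S_i = s_i | S_{pi_i} = s_{pi_i})."""
    p = softmax(phi_root)[root_subsplit]
    for parent, child in child_assignments:
        p *= softmax(phi_child[parent])[child]
    return p

p = sbn_prob("AB|CD", [(("AB|CD", "AB"), "A|B"), (("AB|CD", "CD"), "C|D")])
```

Because each CPT is a softmax over its own support, the induced distribution is normalized by construction; no global partition function is needed.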

B THE TWO-PASS ALGORITHM FOR INTERIOR NODE EMBEDDING

Due to the hierarchical structure of phylogenetic trees, we derive a linear-time two-pass algorithm for solving the linear equations 4. In the first pass, we traverse the tree in postorder and express each node feature as a linear function of that of its parent:
$$x_u = c_u x_{\pi_u} + d_u,$$
for all nodes except the root node, where $\pi_u$ is the parent node of $u$. For all leaf nodes $u \in V_b$, this is straightforward: $c_u = 0$, $d_u = x_u$. For any interior node $u$ that is not the root node, as a result of postorder traversal, $(c_v, d_v)$ is available for all $v \in \mathrm{ch}(u)$ when $u$ is visited. Therefore, we can rewrite the equation for $u$ in 4 as
$$|N(u)|\, x_u = x_{\pi_u} + \sum_{v \in \mathrm{ch}(u)} x_v = x_{\pi_u} + \sum_{v \in \mathrm{ch}(u)} (c_v x_u + d_v).$$
This implies
$$x_u = \frac{1}{|N(u)| - \sum_{v \in \mathrm{ch}(u)} c_v}\; x_{\pi_u} + \frac{\sum_{v \in \mathrm{ch}(u)} d_v}{|N(u)| - \sum_{v \in \mathrm{ch}(u)} c_v},$$
which gives the recursive updating formula in equation 6. In the second pass, we traverse the tree in preorder, visiting the root node $r$ first. From the first pass, $(c_v, d_v)$ is available for all $v \in \mathrm{ch}(r)$. Similarly, from equation 4 we have
$$|N(r)|\, x_r = \sum_{v \in \mathrm{ch}(r)} x_v = \sum_{v \in \mathrm{ch}(r)} (c_v x_r + d_v) \quad\Rightarrow\quad x_r = \frac{\sum_{v \in \mathrm{ch}(r)} d_v}{|N(r)| - \sum_{v \in \mathrm{ch}(r)} c_v}.$$
For any other interior node $u$, as a result of preorder traversal, $x_{\pi_u}$ is available when $u$ is visited, so we can compute its node feature $x_u$ via equation 5.
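The two passes above can be sketched in Python as follows. This is a minimal illustration (the tree encoding, node numbering, and example features are our own choices, not from the paper); it can be checked against a direct dense solve of the same harmonic system:

```python
import numpy as np

def two_pass(children, leaf_feat, root):
    """Linear-time solve of the harmonic equations (4) on a rooted tree.
    children: dict node -> list of child nodes; leaf_feat: fixed leaf features."""
    parent = {c: p for p, cs in children.items() for c in cs}
    order, stack = [], [root]
    while stack:                           # iterative DFS; reversed below -> postorder
        u = stack.pop()
        order.append(u)
        stack.extend(children[u])
    order.reverse()

    c, d = {}, {}                          # first pass: x_u = c_u x_{pi(u)} + d_u
    for u in order:
        if not children[u]:                # leaf: feature is fixed
            c[u], d[u] = 0.0, np.asarray(leaf_feat[u], dtype=float)
        elif u != root:
            denom = (len(children[u]) + 1) - sum(c[v] for v in children[u])
            c[u] = 1.0 / denom
            d[u] = sum(d[v] for v in children[u]) / denom

    x = {}                                 # second pass: root first, then preorder
    denom = len(children[root]) - sum(c[v] for v in children[root])
    x[root] = sum(d[v] for v in children[root]) / denom
    for u in order[::-1]:                  # parents are visited before children
        if u != root:
            x[u] = d[u] if not children[u] else c[u] * x[parent[u]] + d[u]
    return x

# Toy example: root 0 with interior children 1, 2; leaves 3..6 carry one-hot features.
children = {0: [1, 2], 1: [3, 4], 2: [5, 6], 3: [], 4: [], 5: [], 6: []}
leaf_feat = {3 + i: np.eye(4)[i] for i in range(4)}
x = two_pass(children, leaf_feat, root=0)
```

On this example the root feature comes out as the uniform vector $(1/4,\ldots,1/4)$, matching the symmetry of the tree, and every interior feature agrees with a direct linear solve.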

C PROOFS FOR LEMMAS, COROLLARIES, AND THEOREMS

Proof for Lemma 1

Lemma 1. Let $\lambda = \min_{u \in V^o \setminus \{r\}} |N(u)|$. For all interior nodes $u \in V^o \setminus \{r\}$, $0 \le c_u \le \frac{1}{\lambda - 1}$.

Proof. We prove by induction. Since $|N(u)| \ge 2$, $\forall u \in V^o$, we have $\lambda \ge 2$. Note that $c_u = 0$, $\forall u \in V_b$. Suppose $0 \le c_v \le \frac{1}{\lambda - 1}$, $\forall v \in \mathrm{ch}(u)$; it suffices to show that $0 \le c_u \le \frac{1}{\lambda - 1}$ as long as $u$ is not the root node. From the recursive updating formula, we have
$$0 \le c_u = \frac{1}{|N(u)| - \sum_{v \in \mathrm{ch}(u)} c_v} \le \frac{1}{|N(u)| - \sum_{v \in \mathrm{ch}(u)} \frac{1}{\lambda - 1}} = \frac{1}{|N(u)| - \frac{|N(u)| - 1}{\lambda - 1}}.$$
As $2 \le \lambda \le |N(u)|$,
$$(|N(u)| - \lambda)\,\frac{\lambda - 2}{\lambda - 1} \ge 0 \;\Rightarrow\; |N(u)| - \lambda \ge \frac{|N(u)| - \lambda}{\lambda - 1} \;\Rightarrow\; |N(u)| - \frac{|N(u)| - 1}{\lambda - 1} \ge \lambda - 1.$$
Therefore, $0 \le c_u \le \frac{1}{\lambda - 1}$.

Remark. For bifurcating phylogenetic trees, $\lambda = 3$, so $0 \le c_u \le \frac{1}{2}$.

Proof for Lemma 2

Lemma 2 (Extremum Principle). Let $\{x_u \in \mathbb{R}^d \mid u \in V\}$ be a set of $d$-dimensional node features that satisfies equations 4. For $1 \le n \le d$, let $X[n] = \{x_u[n] \mid u \in V\}$ be the set of the $n$-th components of the node features. Then, $\forall 1 \le n \le d$: (i) the extremum values (i.e., maximum and minimum) of $X[n]$ can be achieved at some tip nodes; (ii) if an extremum value is achieved at some interior node, then $X[n]$ has only one member, or in other words, $x_u[n]$ is the same $\forall u \in V$.

Proof. From equations 4, we have
$$x_u = \frac{1}{|N(u)|} \sum_{v \in N(u)} x_v, \quad \forall u \in V^o.$$
In other words, the feature of any interior node is the mean of those of its neighbors. Therefore, $\forall 1 \le n \le d$, the extremum values of $X[n]$ can be achieved at the boundary of the graph, i.e., the tip nodes. On the other hand, if an extremum value of $X[n]$ is achieved at some interior node $u$, then it is also achieved at all the neighbors of $u$. Since the tree topology is connected, this implies that $x_u[n]$ is constant over $u \in V$.

Proof for Theorem 1

Theorem 1. Let $N$ be the number of tip nodes and let $\{x_u \in \mathbb{R}^N \mid u \in V\}$ be the solution to the linear system 4 with one-hot encoding for the tip nodes. Then, $\forall u \in V^o$, we have (i) $0 < x_u[n] < 1$, $\forall 1 \le n \le N$, and (ii) $\sum_{n=1}^N x_u[n] = 1$.

Proof. As the tip node features are one-hot vectors, (i) follows immediately from Lemma 2. Let $S = \{x \in \mathbb{R}^N \mid \sum_{n=1}^N x[n] = 1\}$ and let $P_S$ be the projection onto $S$. Since $\{x_u \in \mathbb{R}^N \mid u \in V\}$ solves the linear system 4, it minimizes the Dirichlet energy. Now consider its projection $\{P_S(x_u) \in \mathbb{R}^N \mid u \in V\}$. Since one-hot vectors are in $S$, $P_S(x_u) = x_u$, $\forall u \in V_b$. As $P_S$ is a projection operator, we have
$$\sum_{(u,v) \in E} \|P_S(x_u) - P_S(x_v)\|^2 \le \sum_{(u,v) \in E} \|x_u - x_v\|^2.$$
Because $\{x_u \in \mathbb{R}^N \mid u \in V\}$ minimizes the Dirichlet energy, equality has to hold, which implies that $x_u - x_v$ is parallel to the hyperplane $S$, $\forall (u,v) \in E$. Therefore, $x_u \in S$, $\forall u \in V^o$, as the tree topology is connected and the tip node features are already in $S$. This proves (ii).

Proof for Corollary 1

Corollary 1. Suppose that the tip node features are linearly independent. Then the interior node features obtained from the solution to the linear system 4 all lie in the interior of the convex hull of the tip node features.

Proof. Let $A$ be the matrix whose columns are the tip node features. From the linear system 4, all interior node features lie in the column space of $A$ and hence can be uniquely represented as $x_u = A y_u$, $\forall u \in V$, as the columns of $A$ are linearly independent. Moreover, $\{y_u \in \mathbb{R}^N \mid u \in V\}$ satisfies the same linear system 4:
$$\sum_{v \in N(u)} (y_u - y_v) = 0, \;\forall u \in V^o, \qquad y_v \text{ one-hot}, \;\forall v \in V_b.$$
By Theorem 1, $\forall u \in V^o$, (i) $0 < y_u[n] < 1$, $\forall 1 \le n \le N$, and (ii) $\sum_{n=1}^N y_u[n] = 1$. Hence each interior feature $x_u = A y_u$ is a strictly convex combination of the tip node features, which concludes the proof.
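The conclusion of Theorem 1 can also be checked numerically on a small example. The 4-tip tree below (two adjacent interior nodes, each attached to two one-hot tips) is an illustrative choice of ours, not an example from the paper:

```python
import numpy as np

# Unrooted 4-tip tree: interior nodes u1 -- u2; tips {0,1} attach to u1, {2,3} to u2.
# Harmonic equations (4): |N(u)| x_u = sum of neighboring features, tips one-hot.
E = np.eye(4)
A = np.array([[3.0, -1.0],        # u1: 3 x_{u1} - x_{u2} = e_0 + e_1
              [-1.0, 3.0]])       # u2: 3 x_{u2} - x_{u1} = e_2 + e_3
b = np.stack([E[0] + E[1], E[2] + E[3]])
X = np.linalg.solve(A, b)         # rows are the interior features x_{u1}, x_{u2}
```

Each interior feature has strictly positive entries that sum to one, i.e., it lies in the open probability simplex spanned by the one-hot tip features, exactly as Theorem 1 states.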



Footnotes. (1) This is trivial for rooted trees since they are directed. For unrooted trees, we can choose an interior node as the root node and use the topological ordering of the corresponding rooted tree. (2) Here $Z(\phi)$ is taken as a free parameter and is included in $\phi$. (3) We use EDGE as an example here for branch length parameterization since it can learn edge features (see Appendix F). All GNN variants (except the simple GCN) performed similarly in this example (see Table 1). (4) The amortization gap on a tree topology $\tau$ is defined as $L(Q^*|\tau) - L(Q_\psi|\tau)$, where $L(Q_\psi|\tau)$ is the ELBO of the approximating distribution $Q_\psi(q|\tau)$ and $L(Q^*|\tau)$ is the maximum lower bound that can be achieved within the same variational family. See Zhang (2020); Cremer et al. (2018) for more details.





Figure 2: Comparison of learnable topological feature based EBMs for probability mass estimation of unrooted phylogenetic trees with 8 leaves using NCE. Left: NCE loss. Middle: KL divergence. Right: EBM approximations vs ground truth probabilities. The NCE loss and KL divergence results were obtained from 10 independent runs and the error bars represent one standard deviation.

Figure 3: Performance on DS1. Left: Lower bounds. Middle & Right: Amortization gaps on trees in the 95% credible sets.

Figure 4: A simple subsplit Bayesian network for a leaf set that contains 4 species A, B, C and D. Left: Some rooted phylogenetic tree examples. Middle: The corresponding SBN assignments. For ease of illustration, subsplit (W, Z) is represented as W Z in the graph. The dashed gray subgraphs represent fake splitting processes where splits are deterministically assigned, and are used purely to complement the networks such that the overall network has a fixed structure. Right: The SBN for these examples. This figure is adapted from Zhang & Matsen IV (2019).



Table 1: Evidence lower bound (ELBO) and marginal likelihood (ML) estimates of different methods across 8 benchmark datasets for Bayesian phylogenetic inference. The marginal likelihood estimates of all variational methods are obtained via importance sampling using 1000 samples, and the results (in units of nats) are averaged over 100 independent runs with standard deviations in brackets. Results for stepping-stone (SS) are from Zhang & Matsen IV (2019) (using 10 independent MrBayes (Ronquist et al., 2012) runs, each with 4 chains run for 10,000,000 iterations and sampled every 100 iterations).

ACKNOWLEDGMENTS

This work was supported by the National Natural Science Foundation of China (grant no. 12201014) and by National Institutes of Health grant AI162611. The research of the author was supported in part by the Key Laboratory of Mathematics and Its Applications (LMAM) and the Key Laboratory of Mathematical Economics and Quantitative Finance (LMEQF) of Peking University. The author is grateful for the computational resources provided by the High-performance Computing Platform of Peking University and appreciates the constructive feedback from the anonymous ICLR reviewers.


Proof for Lemma 3

Lemma 3. Let $\{x_u \mid u \in V\}$ be the solution to the linear system 4, with linearly independent tip node features, and let $\{x_u = \sum_{v \in V_b} a^v_u x_v \mid u \in V^o\}$ be the convex combination representations of the interior node features. For any tip node $v \in V_b$, the coefficient $a^v_u$ is uniquely maximized over $u \in V^o$ at the interior node adjacent to $v$.

Proof. For any tip node $v \in V_b$, let $\bar{v} \in N(v)$ be the adjacent interior node of $v$. It suffices to show that $a^v_{\bar{v}} > a^v_u$, $\forall u \in V^o \setminus \{\bar{v}\}$. As $\{x_v \mid v \in V_b\}$ are linearly independent, from equations 4 the coefficients satisfy the same harmonic equations (equation 11):
$$\sum_{w \in N(u)} (a^v_u - a^v_w) = 0, \;\forall u \in V^o, \qquad a^v_v = 1, \quad a^v_u = 0, \;\forall u \in V_b \setminus \{v\}.$$
Now consider a new tree topology $\bar{\tau}$ that has $v$ removed from $\tau$. Note that equation 11 still holds for all interior nodes except $\bar{v}$; by the extremum principle (Lemma 2), the maximum value of $\{a^v_u \mid u \in V \setminus \{v\}\}$ can therefore be achieved at either the boundary of $\bar{\tau}$ or $\bar{v}$. As $a^v_u = 0$, $\forall u \in V_b \setminus \{v\}$, and $a^v_{\bar{v}} > 0$ (by Corollary 1), the maximum value is achieved at $a^v_{\bar{v}}$, i.e., $a^v_{\bar{v}} \ge a^v_u$, $\forall u \in V^o$. Now suppose there exists an interior node $u \ne \bar{v}$ with $a^v_u = a^v_{\bar{v}}$. Then the maximum would also be achieved at an interior node of $\bar{\tau}$, and Lemma 2(ii) would force $a^v$ to be constant over $\bar{\tau}$, contradicting $a^v_u = 0$ at the tip nodes in $V_b \setminus \{v\}$. The maximum is therefore attained uniquely at $\bar{v}$.

Proof for Theorem 2

Theorem 2. Let $\{x_u\}$ and $\{z_u\}$ be the sets of interior node features that minimize the Dirichlet energy for phylogenetic tree topologies $\tau_x$ and $\tau_z$ respectively, given the same linearly independent tip node features. If $\{x_u\} = \{z_u\}$, then $\tau_x = \tau_z$.

Proof. Consider the case for unrooted trees first. We prove by induction on the number of tip nodes $N$ of the tree topology. For $N = 3$, the tree topology is trivial. If $N > 3$, then for each tip node, Lemma 3 identifies its adjacent interior node by convex combination coefficient maximization. As the number of interior nodes is less than the number of tip nodes, there must be two tip nodes that connect to the same interior node. Merging these two tip nodes into their shared neighbor reduces the problem to one of size $N - 1$, with the shared neighbor being a new tip node. It is easy to check that for this new set of node features, the tip node features are again linearly independent. Therefore, the remaining part of the tree topology can be recovered by the induction hypothesis. The proof for rooted trees is similar.
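The coefficient-maximization argument behind Lemma 3, which drives the tree recovery in the induction, can be verified numerically. The 4-tip tree below is an illustrative choice of ours (with one-hot tips, the convex-combination coefficients are simply the entries of the interior features):

```python
import numpy as np

# 4-tip tree: interiors u1 -- u2 (indices 0, 1); tips {0,1} attach to u1, {2,3} to u2.
E = np.eye(4)
A = np.array([[3.0, -1.0], [-1.0, 3.0]])
X = np.linalg.solve(A, np.stack([E[0] + E[1], E[2] + E[3]]))

# With one-hot tip features, the coefficient a^v_u equals X[u, v].
# Lemma 3: for each tip v, that coefficient is maximized at v's adjacent interior node.
adjacent = {0: 0, 1: 0, 2: 1, 3: 1}   # tip index -> index of its neighboring interior
recovered = {v: int(np.argmax(X[:, v])) for v in range(4)}
```

Here `recovered` reproduces the true tip-to-interior adjacencies, i.e., the interior features alone determine which interior node each tip hangs off, which is the first step of the induction in Theorem 2.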
2022) use symmetry-preserving neural networks to explicitly encode the permutation invariance of phylogenetic trees. However, their method requires enumerating all permutations of tip nodes that leave tree topologies unchanged and hence is challenging to extend to datasets with many sequences. Jiang et al. (2022) learn hyperbolic embeddings of gene sequences for phylogenetic tree placement and updates. While these embeddings may be useful for sequentially updating species trees with only a handful of genes, they are not suitable for learning representations of tree topologies and their local structures for a given taxa set. Fioravanti et al. (2018) use CNNs to incorporate phylogenetic information for the classification of metagenomics data, where the phylogenetic tree is assumed to be known. Their method, therefore, is also not suitable for phylogenetic inference, where the tree topology is unknown and needs to be inferred from the data.

F DETAILS ON GRAPH CONVOLUTIONAL OPERATORS

The following are the update and aggregation functions for the graph convolutional operators used in our experiments.

• Graph convolutional networks (GCN):
$$h_u^{(t+1)} = \sigma\Big(W^{(t)} \sum_{v \in N(u) \cup \{u\}} \frac{h_v^{(t)}}{\sqrt{d_u d_v}}\Big),$$
where $d_u$ is the degree of node $u$.

• Graph isomorphism networks (GIN):
$$h_u^{(t+1)} = \mathrm{MLP}^{(t)}\Big((1 + \epsilon^{(t)})\, h_u^{(t)} + \sum_{v \in N(u)} h_v^{(t)}\Big),$$
where $\epsilon^{(t)}$ can be either a learnable parameter or a fixed scalar.

• GraphSAGE operator (SAGE):
$$h_u^{(t+1)} = W_1^{(t)} h_u^{(t)} + W_2^{(t)} \cdot \mathrm{mean}_{v \in N(u)}\, h_v^{(t)}.$$

• Gated graph convolutional operator (GGNN):
$$h_u^{(t+1)} = \mathrm{GRU}^{(t)}\Big(h_u^{(t)}, \sum_{v \in N(u)} W^{(t)} h_v^{(t)}\Big),$$
where $\mathrm{GRU}^{(t)}$ is a gated recurrent unit, a gating mechanism in recurrent neural networks.

• Edge convolutional operator (EDGE):
$$h_v^{(t+1)} = \sum_{u \in N(v)} e_{u \to v}^{(t+1)}, \qquad e_{u \to v}^{(t+1)} = \mathrm{MLP}^{(t)}\big(h_v^{(t)} \,\|\, h_u^{(t)} - h_v^{(t)}\big),$$
where $\|$ means concatenation and $e_{u \to v}^{(t+1)}$ is the edge feature from node $u$ to node $v$.
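As an example of how such operators translate into code, below is a minimal numpy sketch of the GIN update (the placeholder weights, MLP sizes, and toy graph are our own illustrative choices, not trained parameters from the experiments):

```python
import numpy as np

rng = np.random.default_rng(0)
W1 = rng.standard_normal((3, 16))     # placeholder MLP weights (not trained)
W2 = rng.standard_normal((16, 3))

def gin_update(H, neighbors, eps=0.0):
    """One GIN layer: h_u <- MLP((1 + eps) h_u + sum_{v in N(u)} h_v)."""
    agg = np.stack([(1.0 + eps) * H[u] + H[neighbors[u]].sum(axis=0)
                    for u in range(len(H))])
    return np.maximum(agg @ W1, 0.0) @ W2   # two-layer MLP with ReLU

# Path graph 0 - 1 - 2 with one-hot initial features.
H = np.eye(3)
neighbors = [np.array([1]), np.array([0, 2]), np.array([1])]
H1 = gin_update(H, neighbors)
```

The other operators differ only in the aggregation (normalized sum for GCN, mean plus a separate self-weight for SAGE, a GRU cell for GGNN, per-edge MLPs for EDGE) and can be swapped in behind the same interface.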

