NEURAL POTTS MODEL

Abstract

We propose the Neural Potts Model objective as an amortized optimization problem. The objective enables training a single model with shared parameters to explicitly model energy landscapes across multiple protein families. Given a protein sequence as input, the model is trained to predict a pairwise coupling matrix for a Potts model energy function describing the local evolutionary landscape of the sequence. Couplings can be predicted for novel sequences. A controlled ablation experiment assessing unsupervised contact prediction on sets of related protein families finds a gain from amortization for low-depth multiple sequence alignments; the result is then confirmed on a database with broad coverage of protein sequences.

1. INTRODUCTION

When two positions in a protein sequence are in spatial contact in the folded three-dimensional structure of the protein, evolution is not free to choose the amino acid at each position independently. This means that the positions co-evolve: when the amino acid at one position varies, the assignment at the contacting site may vary with it. A multiple sequence alignment (MSA) summarizes evolutionary variation by collecting a group of diverse but evolutionarily related sequences. Patterns of variation, including co-evolution, can be observed in the MSA. These patterns are in turn associated with the structure and function of the protein (Göbel et al., 1994). Unsupervised contact prediction aims to detect co-evolutionary patterns in the statistics of the MSA and infer structure from them. The standard method for unsupervised contact prediction fits a Potts model energy function to the MSA (Lapedes et al., 1999; Thomas et al., 2008; Weigt et al., 2009). Various approximations are used in practice, including mean field (Morcos et al., 2011), sparse inverse covariance estimation (Jones et al., 2011), and pseudolikelihood maximization (Balakrishnan et al., 2011; Ekeberg et al., 2013; Kamisetty et al., 2013). To construct the MSA for a given input sequence, a similarity query is performed across a large database to identify related sequences, which are then aligned to each other. Fitting the Potts model to the set of sequences identifies statistical couplings between different sites in the protein, which can be used to infer contacts in the structure (Weigt et al., 2009). Contact prediction performance depends on the depth of the MSA and is reduced when few related sequences are available to fit the model. In this work we consider fitting many models across many families simultaneously, with parameter sharing across all the families. We introduce this formally as the Neural Potts Model (NPM) objective.
The objective is an amortized optimization problem across sequence families. A Transformer model is trained to predict the parameters of a Potts model energy function defined by the MSA of each input sequence. This approach builds on ideas from the emerging field of protein language models (Alley et al., 2019; Rives et al., 2019; Heinzinger et al., 2019), which proposes fitting a single model with unsupervised learning across many evolutionarily diverse protein sequences. We extend this core idea to train a model to output an explicit energy landscape for every sequence. To evaluate the approach, we focus on the problem setting of unsupervised contact prediction for proteins with low-depth MSAs. Unsupervised structure learning with Potts models performs poorly when few related sequences are available (Jones et al., 2011; Kamisetty et al., 2013; Moult et al., 2016). Since larger protein families are likely to have structures available, the proteins of greatest interest for unsupervised structure prediction are likely to have lower-depth MSAs (Tetchner et al., 2014). This is especially a problem for higher organisms, where there are fewer related genomes (Tetchner et al., 2014). The hope is that for low-depth MSAs, the parameter sharing in the neural model will improve results relative to fitting an independent Potts model to each family.
[Figure: query sequence x, NPM-predicted Potts parameters W_θ(x), and finite-sample Potts estimate Ŵ*.]

We investigate the NPM objective in a controlled ablation experiment on a group of related protein families in Pfam (Finn et al., 2016). In this artificial setting, information can be generalized by the pre-trained shared parameters to improve unsupervised contact prediction on a subset of the MSAs that have been artificially truncated to reduce their number of sequences. We then study the model in the setting of a large dataset without artificial reduction, training the model on MSAs for UniRef50 sequences. In this setting there is also an improvement on average for low-depth MSAs, both for sequences in the training set and for sequences not in the training set.

2. BACKGROUND

Multiple sequence alignments. An MSA is a set of aligned protein sequences that are evolutionarily related. MSAs are constructed by retrieving related sequences from a sequence database and aligning the returned sequences using a heuristic. An MSA can be viewed as a matrix where each row is a sequence and each column contains an aligned position, after removing insertions and replacing deletions with gap characters.
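For concreteness, here is a minimal sketch of building that matrix. It assumes the A3M convention of the HH-suite tools (uppercase letters are aligned positions, lowercase letters are insertions relative to the query, '-' is a deletion). The 21-symbol alphabet, the mapping of unknown residues to the gap token and the name `msa_to_matrix` are illustrative choices, not taken from the paper.

```python
import numpy as np

# Hypothetical 21-symbol alphabet: 20 amino acids plus the gap character.
ALPHABET = "ACDEFGHIKLMNPQRSTVWY-"
TOKEN = {a: i for i, a in enumerate(ALPHABET)}

def msa_to_matrix(rows):
    """Encode A3M-style aligned rows as an (M, L) integer matrix.

    Lowercase letters (insertions relative to the query) and '.' are removed;
    '-' (a deletion) is kept as the gap token, so every row has the query length L.
    """
    cleaned = ["".join(c for c in row if not (c.islower() or c == ".")) for row in rows]
    lengths = {len(r) for r in cleaned}
    if len(lengths) != 1:
        raise ValueError(f"rows have unequal aligned lengths: {sorted(lengths)}")
    # Unknown residues such as 'X' are mapped to the gap token, for simplicity.
    return np.array([[TOKEN.get(c, TOKEN["-"]) for c in r] for r in cleaned], dtype=np.int64)

msa = msa_to_matrix([
    "MKVLA",    # query
    "MRV-A",    # one deletion
    "MKaaVLG",  # two insertions, which are dropped
])
print(msa.shape)  # (3, 5)
```

Dropping insertion columns is what makes every row the same length as the query, so that column i always refers to query position i.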

Potts model

The generalized Potts model defines a Gibbs distribution over a protein sequence (x_1, ..., x_L) of length L with the negative energy function

$$-E(x; W) = \sum_i h_i(x_i) + \sum_{ij} J_{ij}(x_i, x_j) \tag{1}$$

which defines fields h_i for each position in the sequence and couplings J_ij for every pair of positions. The parameters of the model are W = {h, J}, the set of fields and couplings respectively. The distribution p(x; W) is obtained by normalization as exp{−E(x; W)}/Z(W). Since the normalization constant is intractable, pseudolikelihood is commonly used to fit the parameters (Balakrishnan et al., 2011; Ekeberg et al., 2013). Pseudolikelihood approximates the likelihood of a sequence x as a product of conditional distributions, giving the loss (the negative log-pseudolikelihood)

$$\ell_{PL}(x; W) = -\sum_i \log p(x_i \mid x_{-i}; W).$$

To estimate the Potts model, we take the expectation over an MSA M:

$$\mathcal{L}_{PL}(W) = \mathbb{E}_{x \sim \mathcal{M}}\left[\ell_{PL}(x; W)\right] \tag{2}$$

In practice, we only have a finite set of sequences M̂ in the MSA to estimate Eq. (2). L2 regularization ρ(W) = λ_J‖J‖² + λ_h‖h‖² is added, and sequences are reweighted to account for redundancy (Morcos et al., 2011). We write the regularized finite-sample estimator as

$$\hat{\mathcal{L}}_{PL}(W) = \frac{1}{M_{\mathrm{eff}}} \sum_{m=1}^{M} w_m\, \ell_{PL}(x_m; W) + \rho(W) \tag{3}$$

which sums over all M sequences of the finite MSA M̂, weighted by w_m, where the weights sum to M_eff. The finite-sample estimate of the parameters, Ŵ*, is obtained by minimizing L̂_PL.

Idealized MSA. Notice that in Eq. (2) we idealized the MSA M as a distribution defined by the protein family. We consider the set of sequences actually retrieved in the MSA M̂ in Eq. (3) as a finite sample from this underlying idealized distribution. For some protein families this sample will contain more information than for others, depending on which sequences are present in the database. We will refer to W* as a hypothetical idealized estimate of the parameters, to explain how the Neural Potts Model can improve on the finite-sample estimate Ŵ* for low-depth MSAs.
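The Potts definitions in Eqs. (1) to (3) can be made concrete with a small NumPy sketch. Several choices here are assumptions rather than details from the paper: parameters are dense arrays h of shape (L, K) and J of shape (L, L, K, K) with J_ij(a, b) = J_ji(b, a) and zero self-couplings; the energy counts each unordered pair once; sequence weights use a commonly used 80% identity threshold; and the λ values are placeholders.

```python
import numpy as np

def random_potts(L, K, scale=0.1, seed=0):
    """Random fields h (L, K) and couplings J (L, L, K, K) with J_ij(a, b) = J_ji(b, a)."""
    rng = np.random.default_rng(seed)
    h = scale * rng.normal(size=(L, K))
    J = scale * rng.normal(size=(L, L, K, K))
    J = 0.5 * (J + J.transpose(1, 0, 3, 2))   # symmetrize
    J[np.arange(L), np.arange(L)] = 0.0       # no self-couplings
    return h, J

def neg_energy(x, h, J):
    """-E(x; W), counting each unordered pair (i, j) once."""
    pos = np.arange(len(x))
    pair = J[pos[:, None], pos[None, :], x[:, None], x[None, :]]  # pair[i, j] = J_ij(x_i, x_j)
    return h[pos, x].sum() + np.triu(pair, k=1).sum()

def conditional_logits(x, h, J):
    """logits[i, a] = log p(x_i = a | x_{-i}; W) up to a per-position constant."""
    onehot = np.eye(h.shape[1])[x]                       # (L, K)
    return h + np.einsum("ijab,jb->ia", J, onehot)

def neg_pseudo_loglik(x, h, J):
    """l_PL(x; W) = -sum_i log p(x_i | x_{-i}; W)."""
    logits = conditional_logits(x, h, J)
    m = logits.max(axis=1, keepdims=True)
    log_z = m[:, 0] + np.log(np.exp(logits - m).sum(axis=1))  # stable log-sum-exp
    return -(logits[np.arange(len(x)), x] - log_z).sum()

def sequence_weights(msa, threshold=0.8):
    """w_m = 1 / number of rows with at least `threshold` identity to row m (itself included)."""
    identity = (msa[:, None, :] == msa[None, :, :]).mean(axis=-1)
    return 1.0 / (identity >= threshold).sum(axis=1)

def regularized_pl_loss(msa, h, J, lam_h=0.01, lam_J=0.01):
    """Eq. (3): (1 / M_eff) sum_m w_m l_PL(x_m; W) + lam_J ||J||^2 + lam_h ||h||^2."""
    w = sequence_weights(msa)
    data = sum(w_m * neg_pseudo_loglik(x_m, h, J) for w_m, x_m in zip(w, msa)) / w.sum()
    return data + lam_J * np.sum(J ** 2) + lam_h * np.sum(h ** 2)

h, J = random_potts(L=5, K=3)
msa = np.array([[0, 1, 2, 0, 1], [0, 1, 2, 0, 1], [2, 2, 0, 1, 0]])
print(sequence_weights(msa))  # [0.5 0.5 1. ]
print(regularized_pl_loss(msa, h, J))
```

The conditional p(x_i | x_{-i}) only needs a softmax over the K states at one position, which is why pseudolikelihood avoids the intractable Z(W). The two identical rows share one unit of weight, so M_eff is 2 here rather than 3.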

2.1. AMORTIZED OPTIMIZATION

We review amortized optimization (Shu, 2017), a generalization of amortized variational inference (Kingma & Welling, 2013; Rezende et al., 2014) that uses learning to predict the solutions of continuous optimization problems, making the computation more tractable and potentially generalizing across problem instances. We are interested in repeatedly solving expensive optimization problems

$$W^*(x) = \arg\min_W \mathcal{L}(W; x) \tag{4}$$

where W ∈ R^m is the optimization variable, x ∈ R^n is the input or conditioning variable of the optimization problem, and L: R^m × R^n → R is the objective. We assume W*(x) is unique. We consider the setting of a distribution over optimization problems with inputs x ∼ p(x), and the arg min of those optimization problems W*(x). Amortization uses learning to leverage the shared structure present across the distribution; e.g. a solution W*(x) is likely correlated with another solution W*(x′). Assuming an underlying regularity of the data and the loss L, we can imagine learning to predict the outcome of the optimization problem with an expressive model W_θ(x) such that, ideally, W_θ ≈ W*. Modeling and learning W_θ(x) are the key design decisions when using amortization.

Modeling approaches. In this paper we consider models W_θ(x) that directly predict the solution to Eq. (4) with a neural network, an approach that follows fully amortized variational inference models and the meta-learning method of Mishra et al. (2017). The model can also leverage the objective L(W; x) and its gradient ∇_W L(W; x), e.g. by predicting multiple candidate solutions W and selecting the best one. This is sometimes referred to as semi-amortization or unrolled optimization-based models and is considered in Gregor & LeCun (2010).

Learning approaches. There are two main classes of learning approaches for amortization:

$$\arg\min_\theta \mathbb{E}_{p(x)}\, \mathcal{L}(W_\theta(x); x) \tag{5}$$

$$\arg\min_\theta \mathbb{E}_{p(x)}\, \lVert W_\theta(x) - W^*(x) \rVert_2^2 \tag{6}$$

Gradient-based approaches leverage gradient information of the objective L and optimize Eq. (5), whereas regression-based approaches optimize a distance to ground-truth solutions W*, such as the squared L2 distance in Eq. (6). Prior work has shown that models trained with these objectives can learn to predict the optimal W* directly as a function of x. Given enough regularity of the domain, if we observe new (test) samples x ∼ p(x), we expect the model to generalize and predict the solution to the original optimization problem in Eq. (4). Gradient-based approaches have the computational advantage of not requiring the expensive ground-truth solution W*, while regression-based approaches are less susceptible to poor local optima in L. Gradient-based approaches are used in variational inference (Kingma & Welling, 2013), style transfer (Chen & Schmidt, 2016), meta-learning (Finn et al., 2017; Mishra et al., 2017), and reinforcement learning, e.g. for the policy update in model-free actor-critic methods (Sutton & Barto, 2018). Regression-based approaches are more common in control, for behavioral cloning and imitation learning (Duriez et al., 2017; Ratliff et al., 2007; Bain & Sammut, 1995).
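As a hedged illustration of Eqs. (4) to (6), the sketch below uses a hypothetical quadratic problem family L(W; x) = ½ WᵀQW − xᵀW, whose unique minimizer W*(x) = Q⁻¹x is known in closed form, and a linear amortized model W_θ(x) = Θx. The problem, model, learning rate and step count are illustrative assumptions. Both learning approaches should recover Θ ≈ Q⁻¹, but only the regression-based one calls the per-instance solver.

```python
import numpy as np

# Hypothetical problem family: L(W; x) = 0.5 * W^T Q W - x^T W, minimized by W*(x) = Q^{-1} x.
Q = np.array([[2.0, 0.5], [0.5, 1.0]])

def objective(W, x):
    return 0.5 * W @ Q @ W - x @ W

def solve_exactly(x):
    """Stand-in for an expensive per-instance solver of Eq. (4)."""
    return np.linalg.solve(Q, x)

rng = np.random.default_rng(0)
X = rng.normal(size=(1000, 2))  # training inputs x ~ p(x)

def fit_amortized(kind, lr=0.1, steps=2000):
    """Fit a linear amortized model W_theta(x) = Theta @ x by full-batch gradient descent."""
    Theta = np.zeros((2, 2))
    W_star = np.array([solve_exactly(x) for x in X]) if kind == "regression" else None
    for _ in range(steps):
        pred = X @ Theta.T                       # row n is W_theta(x_n)
        if kind == "gradient":
            # Eq. (5): dL/dW = Q W - x at W = Theta x, then chain rule through Theta.
            residual = pred @ Q - X
            grad = residual.T @ X / len(X)
        else:
            # Eq. (6): gradient of ||Theta x - W*(x)||^2, which needs the solved W*(x).
            grad = 2.0 * (pred - W_star).T @ X / len(X)
        Theta -= lr * grad
    return Theta

theta_grad = fit_amortized("gradient")
theta_reg = fit_amortized("regression")
x_new = np.array([0.3, -1.2])  # an input not seen during training
print(np.allclose(theta_grad @ x_new, solve_exactly(x_new)))  # True
```

Once trained, the amortized model answers a new instance with a single matrix product instead of a solve, which is the property the Neural Potts Model relies on. Real objectives are not quadratic, so in general W_θ only approximates W*.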

3. NEURAL POTTS MODEL

In Eq. (2) we introduced the Potts model for a single MSA M (an aligned set of sequences x̃), whose parameters are obtained by optimizing

$$W^* = \{h^*, J^*\} = \arg\min_W \mathbb{E}_{\tilde{x} \sim \mathcal{M}}\left[\ell_{PL}(\tilde{x}; W)\right].$$

Following Eq. (5), we now introduce a neural network that estimates Potts model parameters from a single sequence, {h_θ(x), J_θ(x)} = W_θ(x), with a single forward pass. We propose minimizing the following objective for the NPM parameters θ, which directly minimizes the Potts model losses in expectation over our data distribution x ∼ D and the corresponding MSAs x̃ ∼ M(x):

$$\mathcal{L}_{NPM}(\theta) = \mathbb{E}_{x \sim \mathcal{D}}\, \mathbb{E}_{\tilde{x} \sim \mathcal{M}(x)}\left[\ell_{PL}(\tilde{x}; W_\theta(x))\right] \tag{7}$$

[Figure 2: training objective L̂_PL(W) and generalization loss L(W), with W*, Ŵ*, W_θ(x), L(W_θ) and L(Ŵ*) marked; annotations "Amortization gap ("underfitting")" and "Inductive gain".]

To compute the loss for a given sequence x, we compute the Potts model parameters W_θ(x) and evaluate the pseudo-likelihood loss ℓ_PL on a set of sequences x̃ from the MSA constructed with x as query sequence. This fits exactly into the amortized optimization framework of Section 2.1, Eq. (5): we train a model to predict the outcome of a set of highly related optimization problems. One key extension to the amortized optimization setup described there is that the model W_θ estimates the Potts model parameters from only the MSA query sequence x, rather than from the full MSA M(x). Thus, our model must learn to distill the protein energy landscape into its parameters, since it cannot look up related proteins at runtime. A full algorithm is given in Appendix A.

Similar to the original Potts model, we need to add a regularization penalty ρ(W) to the main objective. For a finite sample of N different query sequences {x_n} and a corresponding sample of N × M aligned sequences {x_n^m} from the MSAs M(x_n), the finite-sample regularized loss, i.e. the NPM training objective, becomes:

$$\hat{\mathcal{L}}_{NPM}(\theta) = \sum_{n=1}^{N} \left[ \frac{1}{M_{\mathrm{eff}}(n)} \sum_{m=1}^{M} w_n^m\, \ell_{PL}(x_n^m; W_\theta(x_n)) + \rho(W_\theta(x_n)) \right] \tag{8}$$

An inductive generalization gain (see Fig. 2) occurs when the Neural Potts Model improves over the individual Potts model. Intuitively, this is possible because the individual Potts models are not perfect estimates (their MSAs are finite and biased), while the shared parameters of W_θ can transfer information between related protein families and from pre-training with another objective, such as masked language modeling (MLM). Let us start with the standard amortized optimization setting, where we expect an amortization gap (Cremer et al., 2018).
The amortization gap means that W_θ(x) lags behind the optimal W* on the objective L: L(W_θ(x)) > L(W*). This is closely related to underfitting: the model W_θ is not flexible enough to capture W*(x). However, recall that in the Potts model setting there is a finite-sample training objective L̂_PL (Eq. (3)), with minimizer Ŵ*; NPM is trained on this objective summed over queries (Eq. (8)). We can expect an amortization gap on this training objective; however, the gap can now be advantageous. Even if the amortized solution W_θ(x) is only near-optimal on L̂_PL, parameter sharing through θ can likely steer it toward a more generalizable region of the overparametrized domain W, allowing it to transfer information between related instances. The inductive bias of W_θ(x) can allow the neural amortized estimate to generalize better, especially when the finite sample M̂ is poor. This inductive bias depends on the choice of model class for W_θ, its pre-training, and the shared structure between the protein families in the dataset. Concretely, for the generalization loss L we consider not just the pseudo-likelihood loss on test MSA sequences, but also performance on downstream validation objectives such as contact prediction, a proxy for the model's ability to capture the underlying structure of the protein. We will show that for some samples L(W_θ(x)) < L(Ŵ*), i.e. there is an inductive generalization gain. This is visually represented in Fig. 2, and Table 1 compares amortized optimization and NPM, making a connection to multi-task learning (Caruana, 1998).
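The claim that a worse fit to the finite-sample objective can still generalize better can be illustrated with a toy analogue that involves no Potts model at all: estimating the mean of many related "families" from three samples each. The per-family sample mean plays the role of Ŵ*, since it minimizes each family's training loss, while an estimate shrunk halfway toward the pooled mean is a crude stand-in for parameter sharing in W_θ. All distributions, sizes and the shrinkage weight are hypothetical.

```python
import numpy as np

rng = np.random.default_rng(0)
n_families, n_samples = 200, 3         # many families, each with a shallow "MSA"
between_sd, within_sd = 0.3, 1.0       # families are related: their means are close

true_means = between_sd * rng.normal(size=n_families)
samples = true_means[:, None] + within_sd * rng.normal(size=(n_families, n_samples))

# Solo estimate: minimizes each family's finite-sample squared loss (analogue of W-hat*).
solo = samples.mean(axis=1)
# Shared estimate: pull every family halfway toward the pooled mean (analogue of W_theta).
shared = 0.5 * solo + 0.5 * samples.mean()

def training_loss(est):
    """Mean squared error on the observed samples (finite-sample objective)."""
    return np.mean((samples - est[:, None]) ** 2)

def generalization_loss(est):
    """Expected squared error on a fresh sample from each family (closed form)."""
    return np.mean(within_sd ** 2 + (est - true_means) ** 2)

for name, est in [("solo", solo), ("shared", shared)]:
    print(name, training_loss(est), generalization_loss(est))
```

The shared estimate always has the higher training loss, because the sample mean is the exact minimizer of each family's squared error: this is the amortization gap. With families this closely related and samples this few, it is expected to have the lower generalization loss, which is the analogue of the inductive generalization gain.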
Additionally, we could frame NPM as a hypernetwork, a neural network that predicts the weights of a second network (in this case the Potts model), as in, e.g., Gomez & Schmidhuber (2005).

Table 1:
(A) … | Solo: W ∈ R^n | Amortized: W_θ: R^d → R^n, + learner class
(B) Potts → NPM | Finite: pseudo-log-likelihood on a finite MSA M̂, L̂(W) = Σ_m ℓ_PL(x_m; W) | Distribution: L(W) = E[ℓ_PL(x; W)], or contact prediction | Neural Potts: E_x[L̂_M(W_θ(x))] | Solo: W ∈ R^n, + regularization | Amortized: W_θ: R^d → R^n, + learner class
(C) ML → MTL (multi-task learning) | Finite (ERM): L̂(f_θ) = Σ_m ℓ(f_θ(x_m), y_m) | Distribution: L(f_θ) = E_xy[ℓ(f_θ(x), y)] | Multi-task learning: Σ_{t=1}^{T} L̂_t(f_θ^t) for T related tasks | Solo: f_θ: R^d → R, + regularization + learner class | MTL: f_θ^t: R^d → R, + parameter sharing across the f_θ^t

In summary, the goal of the NPM is to "distill" an ensemble of Potts models into a single feedforward model. From a self-supervised learning perspective, rather than supervising the model with the input directly, we use supervision from an energy landscape around the input.
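To make Eq. (8) concrete, the sketch below computes the finite-sample NPM loss for a small batch of queries. The Transformer and multi-head bilinear form are replaced by a deliberately tiny hypothetical predictor `toy_predictor`, the MSA weights are uniform, and a single λ is used for both penalty terms. Only the forward computation is shown; in practice θ would be updated with gradients of this loss (the paper's full algorithm is in Appendix A).

```python
import numpy as np

def neg_pseudo_loglik(x, h, J):
    """l_PL(x; W) = -sum_i log p(x_i | x_{-i}; W) for dense h (L, K) and J (L, L, K, K)."""
    logits = h + np.einsum("ijab,jb->ia", J, np.eye(h.shape[1])[x])
    m = logits.max(axis=1, keepdims=True)
    log_z = m[:, 0] + np.log(np.exp(logits - m).sum(axis=1))
    return -(logits[np.arange(len(x)), x] - log_z).sum()

def toy_predictor(theta, x):
    """Hypothetical stand-in for W_theta(x): predicts {h, J} from the query sequence alone."""
    g = theta["embed"][x]                                  # (L, d) toy "embedding" of x
    h = g @ theta["field"]                                 # (L, K)
    pair = g @ g.T / np.sqrt(g.shape[1])                   # (L, L), symmetric
    np.fill_diagonal(pair, 0.0)                            # no self-couplings
    J = pair[:, :, None, None] * theta["coupling"]         # (L, L, K, K); coupling is symmetric
    return h, J

def npm_loss(theta, queries, msas, weights, lam=0.01):
    """Finite-sample NPM objective in the form of Eq. (8)."""
    total = 0.0
    for x_n, msa_n, w_n in zip(queries, msas, weights):
        h, J = toy_predictor(theta, x_n)                   # one forward pass per query x_n
        fit = sum(w * neg_pseudo_loglik(x_m, h, J) for w, x_m in zip(w_n, msa_n))
        total += fit / w_n.sum() + lam * (np.sum(J ** 2) + np.sum(h ** 2))
    return total

rng = np.random.default_rng(0)
K, d = 4, 8
C = rng.normal(size=(K, K))
theta = {"embed": rng.normal(size=(K, d)),
         "field": 0.1 * rng.normal(size=(d, K)),
         "coupling": 0.05 * (C + C.T)}
queries = [np.array([0, 1, 2, 3, 1]), np.array([2, 2, 0, 1])]
msas = [np.array([[0, 1, 2, 3, 1], [0, 1, 2, 2, 1], [1, 1, 2, 3, 0]]),
        np.array([[2, 2, 0, 1], [2, 3, 0, 1]])]
weights = [np.ones(3), np.ones(2)]   # uniform weights, for brevity
print(npm_loss(theta, queries, msas, weights))
```

The structural point is visible in `npm_loss`: the Potts parameters depend only on the query x_n, yet they are scored on every aligned sequence x_n^m of its MSA, and the same θ is shared across all queries.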

4. EXPERIMENTS

In Section 4.1 we present results on a small set of related protein domain families from Pfam, where we artificially reduce the MSA depth for a few families to study the inductive generalization gain from the shared parameters. In Section 4.2 we present results for a large Transformer trained on MSAs for all of UniRef50. For the main representation g_θ(x) we use a bidirectional Transformer model (Vaswani et al., 2017). To compute the four-dimensional pairwise coupling tensor J_θ(x) from the sequence embedding g_θ(x), we introduce the multi-head bilinear form (mhbf) in Appendix B. One can think of the multi-head bilinear form as the L × L self-attention maps of the Transformer's multi-head attention module, but without softmax normalization. When using mhbf for direct prediction, there are K² heads, one for every amino acid pair (k, l), where K is the number of states per position. For the Pfam experiments, we extend the architecture with convolutional layers after the mhbf, where the final convolutional layer has K² output channels. We initialize g_θ(x) with a Transformer pre-trained with masked language modeling, following Rives et al. (2019). To evaluate Neural Potts Model energy landscapes, we focus on proteins with structures in the Protein Data Bank (PDB), using the magnitude of the couplings after APC correction to rank contacts. The protocol is described in Appendix C.2.
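Appendix B is not reproduced here, so the following is only a sketch of one way a multi-head bilinear form could map embeddings of shape (L, d) to couplings of shape (L, L, K, K), following the description above: one unnormalized bilinear "attention map" per amino-acid pair (k, l), i.e. K² heads, with no softmax. The projection shapes, the 1/sqrt(d_head) scaling and the symmetrization step are assumptions, and the random embeddings merely stand in for g_θ(x).

```python
import numpy as np

def multi_head_bilinear(g, Wq, Wk):
    """Couplings J[i, j, k, l] from per-position embeddings g, one bilinear head per (k, l).

    g:      (L, d)              embeddings, standing in for g_theta(x)
    Wq, Wk: (K, K, d, d_head)   hypothetical query/key projections for each amino-acid pair
    Each head scores all L x L position pairs like a self-attention map, but without softmax.
    """
    d_head = Wq.shape[-1]
    queries = np.einsum("id,kldh->klih", g, Wq)                             # (K, K, L, d_head)
    keys = np.einsum("jd,kldh->kljh", g, Wk)                                # (K, K, L, d_head)
    scores = np.einsum("klih,kljh->ijkl", queries, keys) / np.sqrt(d_head)  # (L, L, K, K)
    # Assumed post-processing: enforce J_ij(k, l) = J_ji(l, k) and zero self-couplings.
    J = 0.5 * (scores + scores.transpose(1, 0, 3, 2))
    L = g.shape[0]
    J[np.arange(L), np.arange(L)] = 0.0
    return J

rng = np.random.default_rng(0)
L, d, K, d_head = 6, 16, 21, 4
g = rng.normal(size=(L, d))
Wq = rng.normal(size=(K, K, d, d_head)) / np.sqrt(d)
Wk = rng.normal(size=(K, K, d, d_head)) / np.sqrt(d)
J = multi_head_bilinear(g, Wq, Wk)
print(J.shape)  # (6, 6, 21, 21)
```

The convolutional layers that the Pfam experiments add after the mhbf are not shown.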

4.1. PFAM CLANS

To study generalization in a controlled setting, we investigate a small set of structurally-related MSAs from the Pfam domain family database (Finn et al., 2016) belonging to the same Pfam clan. We expect that on a collection of related MSAs, information could be generalized to improve performance on low-depth MSAs. Families within a Pfam clan are linked by a distant evolutionary relationship, giving them related but not trivially-similar structure. We obtain contact maps for the sequences in each of the families where a structure is available in the PDB. At test time we input the sequence and compare the generated couplings under the model to the corresponding structure. We compare the NPM to two baselines. The first direct comparison is to an independent Potts model trained directly on the MSA. For the second baseline we construct the "nearest neighbor" Potts model, by aligning each test sequence against all families in the training set, and using the Potts model from the closest matching family. We perform the experiment using a five-fold cross-evaluation scheme, in which we partition the clan's families into five equally-sized buckets. As in standard cross-validation, each bucket will eventually serve as an evaluation set. However, we do not remove the evaluation bucket. Instead, we artificially reduce the number of sequences in the MSAs in the evaluation bucket to a smaller fixed MSA depth. MSAs in the remaining buckets remain unaltered. The goal of this setup is to check the model's ability to infer contacts on artificially limited sets of sequences. Both NPM and the baseline independent Potts model are fit on the reduced set of sequences. Note that while the baseline Potts model uses the reduced MSA of the target directly, NPM is trained on the reduced MSA but evaluated using only the target sequence as input. 
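A minimal sketch of the five-fold cross-evaluation split described above. How the reduced MSAs are drawn is not specified in this section, so keeping the query row plus a random subset of the other rows is an assumption, as are the function name `cross_evaluation_rounds` and the dictionary-of-arrays data layout.

```python
import numpy as np

def cross_evaluation_rounds(families, reduced_depth, n_buckets=5, seed=0):
    """Yield (eval_family_names, training_msas) for each cross-evaluation round.

    families: dict mapping family name -> MSA array of shape (M, L), query in row 0.
    The evaluation bucket is NOT removed: its MSAs are reduced to `reduced_depth`
    sequences and stay in the training data; all other MSAs are left unaltered.
    """
    rng = np.random.default_rng(seed)
    names = rng.permutation(sorted(families))
    for bucket in np.array_split(names, n_buckets):
        eval_names = set(bucket.tolist())
        training_msas = {}
        for name, msa in families.items():
            if name in eval_names:
                # Assumed reduction: keep the query row plus a random subset of the others.
                n_keep = min(reduced_depth - 1, len(msa) - 1)
                keep = rng.choice(np.arange(1, len(msa)), size=n_keep, replace=False)
                msa = msa[np.concatenate([[0], np.sort(keep)])]
            training_msas[name] = msa
        yield sorted(eval_names), training_msas

families = {f"fam{i}": np.full((50, 10), i) for i in range(10)}
for eval_names, training_msas in cross_evaluation_rounds(families, reduced_depth=8):
    print(eval_names, [training_msas[n].shape for n in eval_names])
```

Each round would train one NPM on `training_msas` and evaluate contacts only for the families in that round's reduced bucket.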
We train a separate NPM on each of the five cross-evaluation rounds, evaluate on the structures corresponding to the bucket with reduced […]

We ask whether an amortization gain can be realized in two different settings: (i) for sequences the model has been trained on; (ii) for sequences in the test set. We partition the UniRef50 representative sequences into 90% train and 10% test sets, constructing an MSA for each of the sequences. During training, the model is given a sequence from the train set as input, and the NPM objective is minimized using a sample from the MSA of the input sequence. In each training epoch, we randomly subsample a different set of 30 sequences from the MSA to fit the NPM objective. We use ground-truth structures to evaluate the NPM couplings and the independent Potts model couplings for contact precision. The dataset is further described in Appendix C.4, and details on the model and training are given in Appendix C.1. The independent Potts model baseline is trained on the full MSA. This means that in setting (i) the NPM and the independent Potts models have access to the same underlying MSAs during training. In setting (ii) the independent Potts model is afforded access to the full MSA, whereas the NPM has not been trained on this MSA and must generalize to estimate the couplings. Figure 5 compares the NPM predictions with individual Potts models fit on the MSA. The Neural Potts Model is given only the query sequence as input. On top-L/5 long-range precision, NPM has better precision than independent Potts models for 22.3% of train and 22.7% of test proteins. In Fig. 6 we visualize example proteins with low MSA depth where NPM does better than the individual Potts model. For shallow MSAs, the average performance of NPM is higher than that of the independent Potts model, suggesting an inductive generalization gain.
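The evaluation metric quoted above can be sketched as follows, under assumptions that are ours rather than the paper's (its protocol is in Appendix C.2): the coupling magnitude is the Frobenius norm of each K × K block J_ij, APC is the common sum-based form, and "long range" means a sequence separation of at least 24, a widely used convention.

```python
import numpy as np

def contact_scores(J):
    """Score residue pairs by coupling magnitude, then apply average product correction (APC).

    J: (L, L, K, K) couplings. The Frobenius norm of each K x K block J_ij is used as the
    magnitude (an assumption here; e.g. gap states are not excluded).
    """
    F = np.sqrt((J ** 2).sum(axis=(2, 3)))     # (L, L) coupling magnitudes
    np.fill_diagonal(F, 0.0)
    apc = F.sum(axis=1, keepdims=True) * F.sum(axis=0, keepdims=True) / F.sum()
    return F - apc

def top_l5_long_range_precision(scores, contacts, min_separation=24):
    """Precision of the top-L/5 scored pairs with j - i >= min_separation (assumed cutoff)."""
    L = scores.shape[0]
    i, j = np.triu_indices(L, k=min_separation)   # upper-triangle pairs far apart in sequence
    top = np.argsort(-scores[i, j])[: max(1, L // 5)]
    return contacts[i[top], j[top]].mean()

L = 60
contacts = np.zeros((L, L), dtype=bool)
idx = np.arange(30)
contacts[idx, idx + 30] = contacts[idx + 30, idx] = True
print(top_l5_long_range_precision(contacts.astype(float), contacts))  # 1.0
```

In use, `scores` would come from `contact_scores` applied to either the NPM couplings J_θ(x) or an independent Potts model's couplings, and `contacts` from the PDB structure; the metric needs L > 24 for any long-range pairs to exist.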

5. RELATED WORK

Recently, protein language modeling has emerged as a promising direction for learning representations of protein sequences that are useful across a variety of tasks. Rives et al. (2019) and Rao et al. (2019) trained protein language models with the masked language modeling (MLM) objective originally proposed for natural language processing by Devlin et al. (2019). Alley et al. (2019), Heinzinger et al. (2019), and Madani et al. (2020) trained models with autoregressive objectives. Transformer protein language models trained with the MLM objective learn information about the underlying structure and function of proteins, including long-range contacts (Rives et al., 2019; Vig et al., 2020). This paper builds on the protein language modeling literature and introduces three new ideas: first, supervision with an energy landscape defined by a set of sequences, rather than objectives defined by a single sequence; second, amortized optimization to fit a single model across many different energy landscapes with parameter sharing; and third, a focus on the unsupervised contact prediction setting rather than on the use of representations in a supervised pipeline. Unsupervised structure learning is reviewed in the introduction. The main approach has been to learn a set of constraints from a family of related sequences by fitting a Potts model energy function to the sequences. Our work builds on this idea, but rather than fitting a Potts model to a single family of related sequences, it uses amortized optimization to fit Potts models across many sequence families with parameter sharing in a deep neural network. Supervised learning has produced breakthrough results for protein structure prediction (Xu, 2018; Senior et al., 2019; Yang et al., 2019).
State-of-the-art methods use supervised learning with deep residual networks on co-evolutionary features derived from the unsupervised structure learning pipeline. While Xu et al. (2020) show that reasonable predictions can be made without co-evolutionary features, their work also shows that these features contribute significantly to the performance of state-of-the-art pipelines. Prior work studying protein language models for contact prediction focuses on the supervised setting. Bepler & Berger (2019) studied pre-training an LSTM on protein sequences and fine-tuning on contact data. Rives et al. (2019) and Rao et al. (2019) studied supervised contact prediction from Transformer protein language models. Vig et al. (2020) found that contacts are represented in Transformer self-attention maps. Our work differs from prior work on structure prediction using protein language models by focusing on the unsupervised structure learning setting. A logical extension of this work would be to integrate the Neural Potts Model into the supervised pipeline.

6. DISCUSSION

This paper explores how a protein sequence model can be trained to produce, for each input, a local energy landscape defined by a set of evolutionarily related sequences. The training objective is cast as an amortized optimization problem. By learning to output the parameters of a Potts model energy function across many sequences, the model may learn to generalize across them. We also investigate, formally and empirically, the generalization capability of models trained through amortized optimization. We compare training independent Potts models on the MSA of each sequence with training a single model, using the amortized objective, to predict Potts model parameters for many inputs. Empirically, the amortized objective provides an inductive gain when few related sequences are available in the MSA for training the independent Potts model. A number of direct extensions exist for future work, including further investigation of the model architecture and of the parameterization of the energy function by the deep network, use of the amortized models in a supervised pipeline, and combining independent Potts models with amortized couplings. The hidden representations could also be investigated for structure prediction and other tasks using the approaches in the protein language modeling literature. The main contribution of this work is to directly incorporate information from a set of sequences related to the input into the learning objective. It would be interesting to investigate other approaches for incorporating this type of supervision into models that aim to learn underlying structure from sequence data.

The second equality is the symmetry; the last equality follows by transposing the bilinear form. From $B^{kl} = (B^{lk})^\top$ it follows that $U^{kl} (V^{kl})^\top = V^{lk} (U^{lk})^\top$, and the obvious choice satisfying this is $U^{kl} = V^{lk}$. In the tied parametrization, this simply becomes $U^{kl} = U^k = V^k = V^{lk}$, such that $B^{kl} = U^k (U^l)^\top$.
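The symmetric choice $U^{kl} = V^{lk}$ can be checked numerically. In this NumPy sketch only $U$ (an assumed array layout of shape (K, K, d, d')) is learned, and $V$ is obtained by swapping the two amino acid axes; the resulting couplings satisfy $J_{ij}(k,l) = J_{ji}(l,k)$ even though an individual head $B^{kl}$ is not a symmetric matrix. The function name is hypothetical.

```python
import numpy as np


def mhbf_symmetric(e, U):
    """Untied symmetric multi-head bilinear form.

    e: (L, d) embeddings; U: (K, K, d, d_proj). V^{kl} is set to U^{lk}, so
    B^{kl} = U^{kl} (U^{lk})^T = (B^{lk})^T. Returns J with axes (i, k, j, l).
    """
    V = U.transpose(1, 0, 2, 3)            # V[k, l] = U[l, k]
    P = np.einsum("id,kldr->klir", e, U)   # e_i U^{kl}
    Q = np.einsum("jd,kldr->kljr", e, V)   # e_j V^{kl}
    return np.einsum("klir,kljr->ikjl", P, Q)
```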
Once again, note that the dot product now becomes $J_{ij}(k,l) = (e_i U^k)(e_j U^l)^\top$. We present a tensor decomposition perspective on this multi-head bilinear form in Appendix B.2.

Convolutional layers after multi-head bilinear form. As an extended model architecture, we consider adding convolutional layers after the multi-head bilinear form (used only for the Pfam experiments). In this case, rather than having $K^2$ heads $B^{kl}$, we use an arbitrary number of heads $F$, which becomes the number of channels in the subsequent convolutional layers: $B^f = U^f (V^f)^\top$ for $f = 1, \dots, F$. We add $1 \times 1$ convolutional layers that also have $F$ channels, and the last convolutional layer has $K^2$ output channels. Weight tying and the symmetry considerations of the mhbf do not apply to this model variant.
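A minimal NumPy sketch of the convolutional variant, assuming embeddings `e` of shape (L, d) and per-channel projections `U`, `V` of shape (F, d, d'). A 1 × 1 convolution over the L × L map is simply a linear map over the channel axis applied at every position pair. The ReLU between layers, the bias terms, and the function name are assumptions; the text does not specify them.

```python
import numpy as np


def conv_head_couplings(e, U, V, conv_layers, K=21):
    """Couplings from F bilinear channels followed by 1x1 convolutions.

    e: (L, d) embeddings; U, V: (F, d, d_proj) projections, B^f = U^f (V^f)^T.
    conv_layers: list of (W, b) pairs, W of shape (C_in, C_out); the last layer
    must have C_out = K * K output channels.
    """
    # F bilinear maps: x[i, j, f] = (e_i U^f) . (e_j V^f)
    P = np.einsum("id,fdr->fir", e, U)
    Q = np.einsum("jd,fdr->fjr", e, V)
    x = np.einsum("fir,fjr->ijf", P, Q)
    for n, (W, b) in enumerate(conv_layers):
        # A 1x1 convolution applies the same linear map over channels at every (i, j).
        x = x @ W + b
        if n < len(conv_layers) - 1:
            x = np.maximum(x, 0.0)  # assumed ReLU between layers
    L = e.shape[0]
    # (L, L, K*K) -> (L, L, K, K) -> (L, K, L, K), i.e. axes (i, k, j, l)
    return x.reshape(L, L, K, K).transpose(0, 2, 1, 3)
```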

B.2 TENSOR DECOMPOSITION VIEW ON MULTI-HEAD BILINEAR FORM

We can view the multi-head bilinear form as a tensor decomposition of $J$, for which we use Einstein notation: any index appearing both as a subscript and as a superscript is summed over its range. Let us write the tensor collecting the $U^{kl}$ matrices as $\mathcal{U} \in \mathbb{R}^{K\times K\times d\times d'}$, and index into $\mathcal{U}$ with the same notation as for $U$: $\mathcal{U}^{kl}_{\alpha r} = U^{kl}_{\alpha r}$, with $\alpha, \beta \in [1 \dots d]$ and $r \in [1 \dots d']$. The same holds for $\mathcal{V}$. The $J$ estimate in the full untied asymmetric case, written as a tensor, then becomes
$$J^{kl}_{ij} = e_i^{\alpha}\, \mathcal{U}^{kl}_{\alpha r}\, \mathcal{V}^{kl\,r}_{\beta}\, e_j^{\beta},$$
and the symmetric ($\mathcal{U}^{kl}_{\alpha r} = \mathcal{V}^{lk}_{\alpha r}$) and tied ($\mathcal{U} \in \mathbb{R}^{K\times d\times d'}$) version is
$$J^{kl}_{ij} = e_i^{\alpha}\, \mathcal{U}^{k}_{\alpha r}\, \mathcal{U}^{l\,r}_{\beta}\, e_j^{\beta}.$$
Note that $\mathcal{U}$ and $\mathcal{V}$ are shared across proteins, while the embeddings $e = g_\theta(x)$ are specific to each protein and are produced by a high-capacity sequence-level model.
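The Einstein-notation contractions translate directly into `numpy.einsum` subscripts. The sketch below is an illustration under the assumption that the tensors are stored as arrays of shape (K, K, d, d') (untied) or (K, d, d') (tied); it is not the original implementation, and the function names are hypothetical. The output axes are ordered (i, k, j, l) to match $J \in \mathbb{R}^{L\times K\times L\times K}$.

```python
import numpy as np


def couplings_untied(e, U, V):
    """J[i, k, j, l] = e_i^a U^{kl}_{a r} V^{kl}_{b r} e_j^b (sums over a, b, r).

    e: (L, d); U, V: (K, K, d, d_proj).
    """
    return np.einsum("ia,klar,klbr,jb->ikjl", e, U, V, e)


def couplings_tied_symmetric(e, U):
    """J[i, k, j, l] = e_i^a U^k_{a r} U^l_{b r} e_j^b with U: (K, d, d_proj)."""
    return np.einsum("ia,kar,lbr,jb->ikjl", e, U, U, e)
```

Each output entry equals the explicit bilinear form $e_i B^{kl} e_j^\top$ with $B^{kl} = U^{kl} (V^{kl})^\top$, and the tied symmetric version is invariant under swapping $(i, k)$ with $(j, l)$.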

C EXPERIMENT DETAILS

C.1 TRAINING DETAILS

We summarize the precise model architecture and optimization settings in Table 2. During each NPM training step, for a given input $x$, $M$ sequences $x_m$ are randomly sampled ($M = 100$ or $30$, see Table 2) for evaluating the pseudo-likelihood loss in Eq. (8). Each sequence is selected with a probability determined by its sequence weight $w_m$. One can think of these $M$ sampled sequences as analogous to a minibatch. Note that for the independent Potts model baseline, the Potts model is computed without any downsampling of the MSA. Additionally, in the Pfam experiments the loss term for family $n$ in Eq. (8) is upweighted by a factor $M_{\text{eff}}(n)$, which places more weight on well-formed, deep MSAs and discounts shallower MSAs. In both the Pfam and UniRef experiments, we enforce a maximum sequence length of 512 via random contiguous crops of positions.
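One way to realize the per-step sampling and cropping described above is the NumPy sketch below. Its assumptions are: sampling with replacement (which also works when a reduced MSA has fewer than M rows), a selection probability proportional to the sequence weight $w_m$, and a uniformly placed crop window. The helper names are hypothetical.

```python
import numpy as np


def sample_msa_rows(msa, weights, M, rng):
    """Draw M rows, each with probability proportional to its weight (assumed)."""
    p = np.asarray(weights, dtype=float)
    p = p / p.sum()
    # Sampling with replacement also works when the MSA has fewer than M rows.
    idx = rng.choice(len(msa), size=M, replace=True, p=p)
    return [msa[m] for m in idx]


def random_contiguous_crop(query, rows, max_len=512, rng=None):
    """Crop the query and all aligned rows to the same random window."""
    L = len(query)
    if L <= max_len:
        return query, rows
    rng = rng if rng is not None else np.random.default_rng()
    start = int(rng.integers(0, L - max_len + 1))
    window = slice(start, start + max_len)
    return query[window], [row[window] for row in rows]
```

Applying the same window to the query and to every sampled row keeps the columns aligned, so the pseudo-likelihood loss can be evaluated on the cropped positions directly.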

C.2 VALIDATION DETAILS

To compute precisions, we convert the pairwise couplings $J \in \mathbb{R}^{L\times K\times L\times K}$ to an $L \times L$ pairwise coupling score by (1) zeroing all entries of $J$ corresponding to gap characters, (2) computing the magnitude of the $K \times K$ matrix $J_{ij}$ for every pair of positions $i, j$ via the Frobenius norm, and (3) applying the Average Product Correction (Dunn et al., 2008). True contacts are defined as residue pairs at a distance of at most 8 Å. Precision is calculated as the fraction of true positives among the top $L$, $L/2$, or $L/5$ predicted contacts. In addition to precision, the Area Under the

For each family, a single structure is selected as the target, using the pdbmap included in Pfam. NPM's contact predictions are made using only the sequence belonging to the target structure. To compute the independent Potts model for a given family in the evaluation bucket, the depth-reduced MSA is aligned to the sequence from the target structure, and the Potts model is computed without further downsampling. As an additional baseline, we predict contacts for validation sequences using the Potts model of the "Nearest Neighbor" family in the train set. For a given validation sequence, we calculate its "nearness" to all train families via calls to HHalign, with the sequence and the train family's Pfam seed alignment as input. We select the family with the highest HHalign probability score as the nearest neighbor.
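Steps (1) to (3), the 8 Å contact definition, and top-L precision can be sketched in NumPy as follows. Assumptions not stated in the text: the gap character is token index 20 of the K = 21 vocabulary, and sequence-separation ranges follow the common convention of short (6 to 11), medium (12 to 23), and long (24 or more residues) range. The function names are hypothetical.

```python
import numpy as np


def contact_scores(J, gap_index=20):
    """(L, K, L, K) couplings -> (L, L) APC-corrected coupling scores."""
    J = np.array(J, dtype=float)  # copy so the caller's array is untouched
    # (1) zero all entries involving the gap token (index 20 is an assumption)
    J[:, gap_index, :, :] = 0.0
    J[:, :, :, gap_index] = 0.0
    # (2) Frobenius norm of each K x K block J_ij
    S = np.sqrt((J ** 2).sum(axis=(1, 3)))
    # (3) Average Product Correction: S_ij - S_i. * S_.j / S_..
    return S - S.mean(axis=1, keepdims=True) * S.mean(axis=0, keepdims=True) / S.mean()


def contact_map(distances, threshold=8.0):
    """True contacts: pairwise distance of at most 8 Angstroms."""
    return np.asarray(distances) <= threshold


def top_l_precision(scores, contacts, frac=1.0, min_sep=6, max_sep=None):
    """Precision of the top L*frac pairs with min_sep <= j - i (< max_sep)."""
    L = scores.shape[0]
    i, j = np.triu_indices(L, k=min_sep)
    if max_sep is not None:
        keep = (j - i) < max_sep
        i, j = i[keep], j[keep]
    n_top = max(1, int(L * frac))
    top = np.argsort(-scores[i, j])[:n_top]
    return float(contacts[i[top], j[top]].mean())
```

Under the assumed convention, top-L/5 long-range precision would be `top_l_precision(S, C, frac=0.2, min_sep=24)`, and the same post-processing applies unchanged to NPM and to the independent Potts model couplings.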
The nearest neighbor prediction is generated as follows: (1) the validation sequence is aligned to the selected train family's MSA; (2) an independent Potts model is fit to the selected train family's MSA, yielding a predicted contact map for the train family; (3) the rows and columns of the predicted contact map that align to the validation sequence are extracted to construct a prediction for the validation sequence.
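Step (3) of the nearest-neighbor baseline, extracting the rows and columns that align to the validation sequence, amounts to fancy indexing. In this sketch the alignment is assumed to be given as an integer array mapping each validation position to a train-family column (-1 when unaligned), and unaligned pairs receive a score of -inf so they never enter the top-L predictions; both conventions, and the function name, are hypothetical.

```python
import numpy as np


def transfer_contact_map(train_scores, val_to_train):
    """Map a train family's (L_t, L_t) score map onto a validation sequence.

    val_to_train[p] is the train-MSA column aligned to validation position p,
    or -1 if p is unaligned (hypothetical encoding of the alignment).
    """
    idx = np.asarray(val_to_train)
    L_val = idx.shape[0]
    out = np.full((L_val, L_val), -np.inf)  # unaligned pairs are never predicted
    a = np.flatnonzero(idx >= 0)             # aligned validation positions
    out[np.ix_(a, a)] = train_scores[np.ix_(idx[a], idx[a])]
    return out
```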

C.4 UNIREF50 TRAINING DATA AND SETUP

For the experiments in Section 4.2, we retrieve the UniRef50 (Suzek et al., 2007) database dated 2018-03. The UniRef50 clusters are randomly partitioned into 90% train and 10% test sets. For all sequences, we construct MSAs using HHblits (Steinegger et al., 2019) against the UniClust30_2017_10 database. HHblits is run with default settings for 3 iterations with an E-value of 0.001. Note that with this MSA generation procedure, validation sequences can be included in the MSAs of train sequences. However, validation sequences are guaranteed never to be used as inputs to NPM during training.

Evaluation of contact precision

We use contact precision as a proxy to measure unsupervised structure learning in the underlying Potts model. To define a set of structures for evaluation, we collect structures from the PDB and assign them to either the training sequences or the test sequences. This allows us to examine separately the performance of NPM on sequences from its training set and on sequences from its test set.



Figure 1: (a) Standard Potts model requires constructing an MSA and optimizing parameters W . (b) Neural Potts Model (NPM) predicts W in a single feedforward pass from a single sequence.

for sparse coding, Li & Malik (2016); Andrychowicz et al. (2016); Finn et al. (2017) for meta-learning, and Marino et al. (2018); Kim et al. (2018) for posterior optimization.

Figure 3: Contact prediction precision on Pfam families from the NADP Rossmann clan, at different levels of depth reduction. Columns show (from left to right) short, medium and long-range precision for top-L threshold. Across the metrics, NPM outperforms the independent Potts model trained on the shallowest MSAs, as well as the Nearest Neighbor Potts model baseline.

Figure 6: Examples where NPM outperforms the independent Potts model fit directly on the MSA. NPM top-L/5 LR contact prediction (lower diagonal, red) compared to the independent Potts model prediction (upper diagonal, blue). All ground truth contacts are indicated in black. True and false hits are indicated with dots and crosses, respectively.

Figure 7: A random sample of 1000 MSAs for 1000 sequences in UniRef50 was analyzed. The graph shows cumulative density plots with MSA depth in log scale on the x-axis, for different query sequence coverage requirements (-cov [20, 80]) specified by a call to HHfilter, applied to each MSA. The fraction of MSAs with depth ≤ 10 is 19% (38% when a coverage of 80% is specified), while the fraction of MSAs with depth ≤ 100 is 30% (55% when a coverage of 80% is specified).

Inductive generalization gain (illustration with a 1D loss landscape). $\hat{W}^*$ is the standard Potts model, estimated on the finite observed MSA $M$. Although it minimizes the training objective, it does not achieve perfect generalization performance. However, the Neural Potts Model $W_\theta(x)$ can generalize better than $\hat{W}^*$ through transfer learning from related samples, guided by the inductive bias of the model. We expect this especially when the estimate $\hat{W}$

Ha et al. (2016); Bertinetto et al. (2016). This is related to multi-task learning, with the major difference that in (B) the solo optimization is over a single tensor $W$ in the Potts model, whereas in (C) it is over a function $f_\theta$ in a learning problem. In the amortized/multi-task setting, the distribution over query sequences $x$ in (B) NPM plays the role that different related tasks play in (C) MTL. In the NPM setting (B), $W_\theta$ takes $x$ explicitly as an argument, whereas in (C) MTL there is typically a separate output head per task.


Averages and standard deviations are shown across rounds. Further details are provided for model training in Appendix C.1 and for the Pfam dataset in Appendix C.3. Figure 3 shows the resulting contact prediction performance on the 181 families in the NADP Rossmann clan, with additional results on the P-loop NTPase, HTH, and AB hydrolase clans in Appendix D, Fig. 9. We initialize a 12-layer Transformer with protein language modeling pre-training. Because of the small dataset size, we keep the weights of the base Transformer $g_\theta$ frozen and fine-tune only the final layers. As MSA depth increases, contact precision improves for both NPM and the independent Potts models. For the shallowest MSAs, NPM has higher precision than the independent Potts models. The advantage at low MSA depth is most pronounced for long-range contacts, where NPM outperforms the independent Potts models up to MSA depth 1000. These experiments suggest that NPM is able to realize an inductive gain by sharing parameters in the pre-trained base model as well as in the fine-tuned final layers and output head. Figure 4 shows training trajectories. We observe a near-monotonic decrease of the amortized pseudo-likelihood loss (Eq. (7)) on the MSAs in the evaluation set, and an increase of the top-L long-range contact precision. This indicates that improving the NPM objective improves unsupervised contact precision across the reduced-depth MSAs. Furthermore, we see the expected overfitting for smaller MSA depths: better training loss but worse contact precision. Additionally, we assess the performance of different architecture variants: direct prediction with the multi-head bilinear form (always using symmetry), with or without tied projections, and the addition of convolutional layers after the multi-head bilinear form. The variants are described in detail in Appendix B. Appendix D, Fig. 8 shows that adding convolutional layers after the multi-head bilinear form performs best; among the variants without convolutional layers, the head without weight tying performs best.

We now perform an evaluation in the more realistic setting of the UniRef50 dataset (Suzek et al., 2007). First, we examine MSA depth across UniRef50. Fig. 7 in Appendix C.4 shows that 19% of sequences in UniRef50 have MSAs with 10 or fewer sequences (38% when a minimum query sequence coverage of 80% is specified).

4.2. UNIREF50

Optimize the regularized loss in Eq. (8)
end while

B MODEL ARCHITECTURE: MULTI-HEAD BILINEAR FORM FOR PAIRWISE COUPLINGS

In this section, we describe the model architecture used to compute a four-dimensional pairwise coupling tensor $J_\theta(x)$ from the sequence embedding $g_\theta(x)$.

B.1 MULTI-HEAD BILINEAR FORM

We write $L$ for the sequence length and $K = 21$ for the size of the amino acid vocabulary. The single-site potentials are $h \in \mathbb{R}^{L\times K}$, and the pairwise couplings are a four-dimensional tensor $J \in \mathbb{R}^{L\times K\times L\times K}$. We start with a sequence-level model that produces the embedding $e$ of the sequence (typically the final hidden layer output): $e = g_\theta(x) \in \mathbb{R}^{L\times d}$. The estimator for the single-site potentials $h_\theta(x)$ is a linear projection layer on the embedding: $h_\theta(x) = g_\theta(x) P_h$ with $P_h \in \mathbb{R}^{d\times K}$. We now discuss how to parametrize the estimator $J_\theta(x) \in \mathbb{R}^{L\times K\times L\times K}$.

Multi-head bilinear form for direct prediction. We introduce a multi-head bilinear form (mhbf) on the embedding $e$: for every pair $k, l$ of amino acids we have a bilinear form, parametrized by a learned interaction matrix $B^{kl} \in \mathbb{R}^{d\times d}$, connecting the hidden states $e_i, e_j \in \mathbb{R}^{1\times d}$ at positions $i$ and $j$. We thus compute the $K^2$ bilinear forms for amino acid pairs $(k, l)$ between the $L \times L$ position pairs $(i, j)$: $J_{ij}(k,l) = e_i B^{kl} e_j^\top$. We always use a low-rank decomposition $B^{kl} = U^{kl} (V^{kl})^\top$ with $U^{kl}, V^{kl} \in \mathbb{R}^{d\times d'}$, so that the bilinear form becomes an inner product in the lower-dimensional space, with $d'$ the projection dimension: $J_{ij}(k,l) = (e_i U^{kl})(e_j V^{kl})^\top$. We can interpret this as an inner product of the embeddings at positions $i$ and $j$ after a linear projection to a space specific to the amino acid pair $(k, l)$. This low-rank multi-head bilinear form is similar to the multi-head attention mechanism introduced by Vaswani et al. (2017), but without softmax normalization. In terms of notation, our parameters $\theta$ include the parameters of the Transformer that produces the embedding and the components of the interaction matrices $\{U^{kl}, V^{kl}\}$.

Direct prediction: tied projection. One way to reduce the number of parameters in the multi-head bilinear form is to avoid a separate low-rank decomposition for each of the $K^2$ heads $B^{kl}$. Instead, we can tie the projection matrices across amino acids: $U^{kl} = U^k$ and $V^{kl} = V^l$, such that head $B^{kl} = U^k (V^l)^\top$.
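The multi-head bilinear form can be computed without materializing the $K^2$ matrices $B^{kl}$ by first projecting the embeddings. The NumPy sketch below assumes array shapes (K, K, d, d') for the untied $U^{kl}, V^{kl}$ and (K, d, d') for the tied $U^k, V^l$; it illustrates the equations above and is not the original implementation. Function names are hypothetical.

```python
import numpy as np


def single_site(e, P_h):
    """h = g(x) P_h: (L, d) @ (d, K) -> (L, K) single-site potentials."""
    return e @ P_h


def mhbf_untied(e, U, V):
    """J[i, k, j, l] = (e_i U^{kl}) . (e_j V^{kl}); U, V: (K, K, d, d_proj)."""
    P = np.einsum("id,kldr->klir", e, U)   # e_i U^{kl}
    Q = np.einsum("jd,kldr->kljr", e, V)   # e_j V^{kl}
    return np.einsum("klir,kljr->ikjl", P, Q)


def mhbf_tied(e, U, V):
    """Tied projections U^{kl} = U^k, V^{kl} = V^l; U, V: (K, d, d_proj)."""
    P = np.einsum("id,kdr->kir", e, U)     # e_i U^k
    Q = np.einsum("jd,ldr->ljr", e, V)     # e_j V^l
    return np.einsum("kir,ljr->ikjl", P, Q)
```

Tying reduces the projection parameters from $2K^2 d d'$ to $2K d d'$, at the cost of forcing every head $B^{kl}$ to share factors with the other heads in its row and column.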
Note that in this case the dot product is taken after a linear projection specific to the single-site amino acids $k$ and $l$ separately: $J_{ij}(k,l) = e_i B^{kl} e_j^\top = (e_i U^k)(e_j V^l)^\top$.

Direct prediction: symmetry. The estimator of $J$ can, and arguably should, be parametrized to be symmetric under interchanging both $i, j$ and $k, l$: $J_{ij}(k,l) = J_{ji}(l,k)$, i.e. it makes no difference in which order we consider the interaction between amino acid $k$ at position $i$ and amino acid $l$ at position $j$. This does not require each interaction matrix $B^{kl}$ to be symmetric. We ask that $e_i B^{kl} e_j^\top = J_{ij}(k,l) = J_{ji}(l,k) = e_i (B^{lk})^\top e_j^\top$.

For the independent Potts model baselines in all experiments, we use CCMpred (Seemayer et al., 2014), a GPU implementation of pseudolikelihood maximization (Balakrishnan et al., 2011). The coupling matrix $J$ from the independent Potts model is processed in the same way, following steps (1) to (3) described above.

C.3 PFAM TRAINING DATA AND SETUP

Data selection. We use the Pfam database (Finn et al., 2016) version 28.0. All MSAs in the HTH (n=217), P-loop NTPase (n=198), NADP Rossmann (n=181), and AB hydrolase (n=67) clans were parsed from the multiple alignment file Pfam-A.full. We apply two preprocessing steps to all MSAs. First, for speed, we load at most 100k sequences from each MSA. Next, we apply HHfilter from the HHSuite3 toolset (Steinegger et al., 2019), with all default settings, to each MSA. We find that filtering improves the contact prediction accuracy of the independent Potts model baseline.

Dataset splits. We perform the experiment using a five-fold cross-evaluation scheme in which we partition the clan's families into five equally sized buckets. As in standard cross-validation, each bucket eventually serves as the evaluation set. However, we do not remove the evaluation bucket; instead, we artificially reduce the number of sequences in each MSA of the evaluation bucket to a small fixed MSA depth (i.e., purging the MSA). All Pfam experiments are repeated 5 times, each with a different choice of reduced bucket. In our figures, we plot average results, with confidence interval bounds given by the standard deviation across the five-fold cross-evaluation. During NPM training, we iterate over the set of MSAs in the four buckets that have not been reduced, as well as the reduced bucket. At a given training step, we randomly select a sequence $x$ within an MSA as input to NPM. This selection is likely to return a sequence containing gap characters. We drop these gap characters and the corresponding columns of the MSA. We then randomly subsample 100 sequences from the MSA to fit the NPM objective. The procedure is described in more detail in Appendix C.1.

Evaluation. During evaluation, we assess NPM and the independent Potts model via the contact prediction task described in Appendix C.2, on the families in the evaluation bucket.
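Dropping the gap columns of the randomly selected input sequence can be sketched as below. The sketch assumes the MSA is a list of equal-length strings in which both '-' and '.' denote gaps; real Pfam-A.full (Stockholm) files also contain lowercase insert states, which this hypothetical helper does not handle.

```python
def drop_query_gap_columns(msa, query_index, gaps=("-", ".")):
    """Remove every column in which the chosen input row has a gap.

    msa: list of equal-length aligned strings; query_index: row used as input x.
    After the call, the input row is gap-free and all rows keep equal length.
    """
    query = msa[query_index]
    keep = [c for c, ch in enumerate(query) if ch not in gaps]
    return ["".join(row[c] for c in keep) for row in msa]
```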
Note that the structures are used only to evaluate the unsupervised contact prediction performance of the model; the model is never trained on structures. We query the Protein Data Bank (PDB) to obtain a list of all protein structures with a resolution of less than 2.5 Å, a length greater than 40 residues, and a submission date before May 1, 2020. We search each PDB entry for hits against the sequences in the NPM training and test sets, respectively. If a PDB entry retrieves hits only to training sequences, we assign it to the training-sequences group. If it retrieves hits only to test sequences, we assign it to the test-sequences group. Any PDB entry that hits both training and test sequences, or neither, is discarded. To perform the search we use the MMseqs2 software suite (Steinegger & Söding, 2018) with default settings, at 50% sequence identity and 80% target coverage. We then cluster each of the two groups of PDB entries at 50% sequence similarity, resulting in a dataset of 11040 structures assigned to train sequences and 211 structures assigned to validation sequences. MSA construction for the PDB entries follows exactly the procedure for UniRef50 (first paragraph); the method for contact prediction from the model couplings (for NPM or the independent Potts model) is described in Appendix C.2.

We show precisions at a fixed top-L threshold, while on the x-axis we vary the sequence separation range and two levels of MSA depth reduction (10 and 1000). Standard deviations over the five-fold cross-evaluation are shown. For the direct multi-head bilinear form (mhbf) prediction (tied or untied), we found an improvement from using a $U, V$ projection dimension of 512 rather than 128. Other hyper-parameters follow Table 2.
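The assignment rule for PDB entries is a small set-membership test. In this sketch the MMseqs2 search results are assumed to be pre-summarized as a mapping from PDB identifier to the set of splits ('train' or 'test') hit by that entry; the function name and input format are hypothetical.

```python
def assign_pdb_entries(hits):
    """Split PDB entries into train-only and test-only groups.

    hits: dict mapping a PDB id to the set of split labels ("train", "test")
    of the UniRef50 sequences it hits. Returns (train_ids, test_ids).
    """
    train, test = [], []
    for pdb_id, labels in hits.items():
        if labels == {"train"}:
            train.append(pdb_id)
        elif labels == {"test"}:
            test.append(pdb_id)
        # Entries hitting both splits, or neither, are discarded.
    return sorted(train), sorted(test)
```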

