ENFORCING PREDICTIVE INVARIANCE ACROSS STRUCTURED BIOMEDICAL DOMAINS

Abstract

Many biochemical applications, such as molecular property prediction, require models to generalize beyond their training domains (environments). Moreover, natural environments in these tasks are structured, defined by complex descriptors such as molecular scaffolds or protein families. As a result, most environments are either never seen during training or contain only a single training example. To address these challenges, we propose a new regret minimization (RGM) algorithm and its extension for structured environments. RGM builds on invariant risk minimization (IRM) by recasting the simultaneous optimality condition in terms of predictive regret, finding a representation that enables the predictor to compete against an oracle with hindsight access to held-out environments. The structured extension adaptively highlights variation due to complex environments via specialized domain perturbations. We evaluate our method on multiple applications, including molecular property prediction and protein homology and stability prediction, and show that RGM significantly outperforms previous state-of-the-art baselines.

1. INTRODUCTION

In many biomedical applications, training data is necessarily limited or otherwise heterogeneous. It is therefore important to ensure that model predictions derived from such data generalize substantially beyond where the training samples lie. For instance, in molecular property prediction (Wu et al., 2018), models are often evaluated under scaffold split, which introduces structural separation between the chemical spaces of training and test compounds. In protein homology detection (Rao et al., 2019), the split is driven by protein superfamily: entire evolutionary groups are held out from the training set, forcing models to generalize across larger evolutionary gaps.

The key technical challenge is to estimate models that can generalize beyond their training data. The ability to generalize implies a notion of invariance to the differences between the available training data and where predictions are sought. A recently proposed approach, invariant risk minimization (IRM) (Arjovsky et al., 2019), seeks predictors that are simultaneously optimal across different such scenarios (called environments). Indeed, one can apply IRM with environments corresponding to molecules sharing the same scaffold (Bemis & Murcko, 1996) or proteins from the same family (El-Gebali et al., 2019) (see Figure 1). However, this is challenging since, for example, scaffolds are structured objects and can often uniquely identify each example in the training set. Creating single-example environments is not helpful, as the model would then attribute any variation from one example to another to scaffold variation.

In this paper, we propose a regret minimization algorithm that handles both standard and structured environments. The basic idea is to simulate unseen environments by using part of the training set as held-out environments E_e.
We quantify generalization in terms of regret: the difference between the losses of two auxiliary predictors trained with and without the examples in E_e. This imposes a stronger constraint on φ and avoids some undesired representations admitted by IRM. For structured environments like molecular scaffolds, we simulate unseen environments by perturbing the representation φ. The perturbation is defined as the gradient of an auxiliary scaffold classifier with respect to φ. The difference between the original and perturbed representation highlights the scaffold variation to the model. The associated regret measures how well a predictor trained without perturbation generalizes to the perturbed examples. The goal is to characterize scaffold variation without explicitly creating an environment for every possible scaffold; in one of the MoleculeNet datasets (Wu et al., 2018), there are 1600 scaffold environments, 75% of which contain a single example.

Our methods are evaluated on real-world datasets covering molecular property prediction and protein classification. We compare our model against multiple baselines, including IRM, MLDG (Li et al., 2018a) and CrossGrad (Shankar et al., 2018). On the QM9 dataset (Ramakrishnan et al., 2014), our methods achieve significant error reductions over the strongest baselines.

2. RELATED WORK

Generalization challenges in biomedical applications The challenges of generalization have been extensively documented in this area. For instance, Yang et al. (2019); Rao et al. (2019); Hou et al. (2018) have demonstrated that state-of-the-art models exhibit a drop in performance when tested under scaffold or protein family split. Indeed, scaffold split and its variants (Feinberg et al., 2018) are commonly used in cheminformatics because they emulate the temporal evaluation adopted in the pharmaceutical industry. The ability to generalize to new scaffold or protein family environments is therefore key to the practical usage of these models. Moreover, input objects in these domains are typically structured (e.g., molecules are represented by graphs (Duvenaud et al., 2015; Dai et al., 2016; Gilmer et al., 2017)). This characteristic introduces unique challenges with respect to the environment definition for IRM-style algorithms.

Invariance Prior work has sought generalization by enforcing an appropriate invariance constraint over learned representations. For instance, the domain adversarial network (DANN) (Ganin et al., 2016; Zhao et al., 2018) enforces the latent representation Z = φ(X) to have the same distribution across different environments E (i.e., Z ⊥ E). However, this forces the predicted label distribution P(Y|Z) to be the same across all environments (Zhao et al., 2019). Long et al. (2018); Li et al. (2018c); Combes et al. (2020) extend the invariance criterion by conditioning on the label in order to address the label shift issue of DANN. Invariant risk minimization (IRM) (Arjovsky et al., 2019) seeks a different notion of invariance. Instead of aligning the distributions of Z, IRM requires that the predictor f operating on Z = φ(X) be simultaneously optimal across different environments. The associated independence is Y ⊥ E | Z. Various works (Krueger et al., 2020; Chang et al., 2020) have sought to extend IRM.
We focus on the structured setting, where most of the environments can uniquely specify X in the training set. As a result, E would act similarly to X. In the extreme case, the IRM principle reduces to Y ⊥ X | Z, which is not the desired invariance criterion. We propose to address this issue by introducing domain perturbation to adaptively highlight the structured variation.

Domain generalization

These methods seek to learn models that generalize to new domains (Muandet et al., 2013; Ghifary et al., 2015; Motiian et al., 2017; Li et al., 2017; 2018b). Domain generalization methods can be roughly divided into three categories: domain adversarial training (Ganin et al., 2016; Tzeng et al., 2017; Long et al., 2018), meta-learning (Li et al., 2018a; Balaji et al., 2018; Li et al., 2019a;b; Dou et al., 2019) and domain augmentation (Shankar et al., 2018; Volpi et al., 2018). Our method resembles meta-learning based methods in that we create held-out environments to simulate domain shift during training. However, our objective seeks to reduce the regret between predictors trained with or without access to the held-out environments. Existing domain generalization benchmarks assume that each domain contains a sufficient amount of data. We focus on a different setting where most environments contain only a few (or single) examples, since they are defined by structured descriptors. This setting often arises in chemical and biological applications (see Figure 1). Similar to the data augmentation method of Shankar et al. (2018), our structured RGM also creates perturbed examples based on domain-guided perturbations. However, our method operates over learned representations, since our inputs are discrete. Moreover, the perturbed examples are only used to regularize the feature extractor φ via the regret term.

3. REGRET MINIMIZATION

To introduce our method, we start with a standard setting where the training set D is comprised of n environments E = {E_1, ..., E_n} (Arjovsky et al., 2019). Each environment E_i consists of examples (x, y) randomly drawn from some distribution P_i. Assuming that new environments we may encounter at test time exhibit similar variability as the training environments, our goal is to train a model that generalizes to such new environments E_test.

Suppose our model consists of two components f • φ, where the predictor f operates on the feature extractor φ. Let L_e(f • φ) = Σ_{(x,y)∈E_e} ℓ(y, f(φ(x))) be its empirical loss in environment E_e, and L(f • φ) = Σ_e L_e(f • φ). IRM learns φ and f such that f is simultaneously optimal in all training environments:

    min_{φ,f} L(f • φ)   s.t.  ∀e: f ∈ arg min_h L_e(h • φ)    (1)

One possible way to solve this objective is through Lagrangian relaxation:

    min_{φ,f} L(f • φ) + Σ_e λ_e [ L_e(f • φ) − min_h L_e(h • φ) ]    (2)

The regularizer L_e(f • φ) − min_h L_e(h • φ) measures the performance gap between f and the best predictor ĥ ∈ F_e(φ) = arg min_h L_e(h • φ) specific to environment E_e. Note that both f and ĥ are trained and evaluated on examples from environment E_e. This motivates us to replace the regularizer with a predictive regret. Specifically, for each environment E_e, we define the associated regret R_e(φ) as the difference between the losses of two auxiliary predictors trained with and without access to the examples (x, y) ∈ E_e:

    R_e(φ) = L_e(f_{−e} • φ) − min_{h∈F} L_e(h • φ) = L_e(f_{−e} • φ) − L_e(f_e • φ)    (3)

where the two auxiliary predictors are obtained from (assuming F is bounded and closed):

    f_e ∈ F_e(φ) = arg min_{h∈F} L_e(h • φ)
    f_{−e} ∈ F_{−e}(φ) = arg min_{h∈F} Σ_{k≠e} L_k(h • φ)    (4)

The oracle predictor f_e is trained on environment E_e, while f_{−e} uses the rest of the environments E \ {E_e} for training but is tested on E_e.
Note that R_e(φ) does not depend on the predictor f we are seeking to estimate; it is a function of the representation φ as well as the two auxiliary predictors f_{−e} and f_e. For notational simplicity, we omit R_e(φ)'s dependence on f_{−e} and f_e. Since both predictors are evaluated on the same set of training examples in E_e, we immediately have:

Proposition 1. The regret R_e(φ) is always non-negative for any representation φ.

The proof is straightforward since f_e is the minimizer of L_e(f • φ) and both f_e and f_{−e} are drawn from the same parametric family F. The overall regret R(φ) = Σ_e R_e(φ) expresses our stated goal of finding a representation φ that generalizes to each held-out environment. Our regret minimization (RGM) objective regularizes the empirical loss with a regret term weighted by λ:

    L_RGM = L(f • φ) + λ Σ_e R_e(φ)    (5)
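As a concrete illustration, the regret of Eq. (3) can be computed exactly when F is the family of linear least-squares predictors, since the per-environment minimizer has a closed form. The following numpy sketch uses synthetic data and illustrative names (`fit_ls`, `regret` are not from the paper); it also makes Proposition 1 checkable, because `f_e` is the exact in-environment minimizer.

```python
import numpy as np

rng = np.random.default_rng(0)

def fit_ls(Z, y):
    # Least-squares predictor h(z) = z @ w minimizing the squared loss on (Z, y).
    return np.linalg.lstsq(Z, y, rcond=None)[0]

def loss(Z, y, w):
    return np.mean((Z @ w - y) ** 2)

def regret(envs, e, phi):
    # R_e(phi) = L_e(f_{-e} . phi) - L_e(f_e . phi), as in Eq. (3).
    Z_e, y_e = phi(envs[e][0]), envs[e][1]
    X_rest = np.vstack([envs[k][0] for k in range(len(envs)) if k != e])
    y_rest = np.hstack([envs[k][1] for k in range(len(envs)) if k != e])
    f_e = fit_ls(Z_e, y_e)                    # oracle predictor, trained on E_e
    f_minus_e = fit_ls(phi(X_rest), y_rest)   # trained without access to E_e
    return loss(Z_e, y_e, f_minus_e) - loss(Z_e, y_e, f_e)

phi = lambda X: X  # identity representation, for illustration only
envs = []
for shift in (0.0, 2.0):  # two environments related by a covariate shift
    X = rng.normal(size=(50, 3)) + shift
    y = X @ np.array([1.0, -1.0, 0.5]) + 0.1 * rng.normal(size=50)
    envs.append((X, y))

regrets = [regret(envs, e, phi) for e in range(len(envs))]
```

Because `f_e` exactly minimizes the in-environment loss, each entry of `regrets` is non-negative up to floating-point error, matching Proposition 1.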

3.1. COMPARISON WITH IRM

Compared to IRM, the proposed RGM objective imposes a stronger constraint on φ, since f_{−e} is not trained on E_e. To show this formally, let F_e(φ), F_{−e}(φ) be the sets of optimal predictors in E_e and E \ {E_e} respectively, as defined in Eq. (4). Since R_e(φ) = 0 ⇔ f_{−e} ∈ F_e(φ) and f_{−e} is chosen arbitrarily from F_{−e}(φ), the constrained form of the RGM objective can be stated as

    min_{φ,f} L(f • φ)   s.t.  ∀e: F_{−e}(φ) ⊆ F_e(φ)    (6)

The analogous IRM constraints are f ∈ ∩_e F_e(φ) and ∩_e F_e(φ) ≠ ∅. Suppose both the IRM and RGM constraints are feasible, and let L*_IRM, L*_RGM be their optimal losses respectively. Consider the sets of optimal features under both objectives:

    Φ_IRM = {φ | min_{f∈∩_e F_e(φ)} L(f • φ) = L*_IRM, ∩_e F_e(φ) ≠ ∅}    (7)
    Φ_RGM = {φ | min_{f∈F} L(f • φ) = L*_RGM, ∀e: F_{−e}(φ) ⊆ F_e(φ)}    (8)

Proposition 2. Assuming two environments, if L*_RGM = L*_IRM, then Φ_RGM ⊆ Φ_IRM. The converse Φ_IRM ⊆ Φ_RGM does not hold in general.

While limited to two environments, the proposition suggests that RGM imposes stronger constraints on φ. Figure 2 shows a counterexample illustrating that Φ_IRM ⊄ Φ_RGM. Suppose there are two environments generated by translation of X_1 and the true hypothesis is I[X_2 > 0]. The identity mapping φ(X) = (X_1, X_2) is not translation invariant, but φ ∈ Φ_IRM because there exists a predictor f_IRM that is simultaneously optimal in all environments.

Figure 2: A counterexample illustrating that Φ_IRM ⊄ Φ_RGM. The environments E_1, E_2 are generated by different translations of X_1; under the identity mapping φ(X) = (X_1, X_2), the true hypothesis is I[X_2 > 0]. There exists a predictor f_IRM which is simultaneously optimal in all environments. In contrast, φ is not feasible under RGM because there is a linear classifier h ∈ F_{−2}(φ) that is optimal in environment E_1 but performs poorly in environment E_2.
On the other hand, φ is not feasible under RGM because there is a linear classifier h ∈ F_{−2}(φ) that is optimal in E_1 but suboptimal in E_2, violating the RGM constraint F_{−2}(φ) ⊆ F_2(φ). Thus φ ∉ Φ_RGM.

To see why it is helpful to impose a stronger constraint on φ, consider the following data generation process, where the environment e can be inferred from the input x alone:

    p(x, y, e) = p(e) p(x|e) p(y|x, e);   p(y|x, e) = p(y|x, e(x))    (9)

For molecules and proteins, this assumption is often valid because the environment labels (scaffolds, protein families) typically depend on x only. We call φ label-preserving if it retains all the information about the label: p(y|φ(x)) = p(y|x, e). Such a representation may not generalize to new environments, given its dependence on e through φ. However, we can show that for any label-preserving φ, its associated ERM optimal predictor also satisfies the IRM constraints:

Proposition 3. For any label-preserving φ with p(y|φ(x)) = p(y|x, e), its associated ERM optimal predictor f* satisfies the IRM constraint. Moreover, if φ ∈ Φ_IRM, f* • φ is optimal under IRM.

While the IRM constraints are vacuous for any label-preserving φ, this is not necessarily the case with the RGM constraints. Consider the counterexample in Figure 2: the identity mapping φ(X) = (X_1, X_2) is label-preserving since it retains all the input information, yet φ is infeasible under RGM.
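The Figure 2 counterexample can be checked numerically. The sketch below uses small hypothetical datasets (the coordinates and names are illustrative, not from the paper): in E_1 the signs of X_1 and X_2 agree, so a classifier using only X_1 is optimal there, yet after translating X_1 in E_2 it fails, while the true rule I[X_2 > 0] stays simultaneously optimal.

```python
import numpy as np

def accuracy(X, y, w):
    # Fraction of points where the linear rule sign(X @ w) matches the 0/1 label.
    return np.mean((X @ w > 0) == y)

# E_1: X_1 and X_2 agree in sign, so both sign(X_1) and sign(X_2) are optimal.
X_env1 = np.array([[1.0, 1.0], [2.0, 0.5], [-1.0, -1.0], [-2.0, -0.5]])
y_env1 = np.array([1, 1, 0, 0])
# E_2: X_1 is translated by +5; sign(X_1) now fires on every point.
X_env2 = X_env1 + np.array([5.0, 0.0])
y_env2 = y_env1

h_spurious = np.array([1.0, 0.0])  # h in F_{-2}(phi): relies only on X_1
f_irm = np.array([0.0, 1.0])       # I[X_2 > 0]: simultaneously optimal

acc_h_e1 = accuracy(X_env1, y_env1, h_spurious)  # optimal on E_1
acc_h_e2 = accuracy(X_env2, y_env2, h_spurious)  # chance level on E_2
```

Since `h_spurious` is optimal on E_1 but not on E_2, it witnesses F_{−2}(φ) ⊄ F_2(φ), so the identity representation violates the RGM constraint even though the IRM constraint is satisfiable via `f_irm`.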

3.2. STRUCTURED ENVIRONMENTS

Now let us consider a more challenging setting, where the environments {E_k} are structured (i.e., k is a structured object rather than an integer). Formally, the training set comes in the form D = {(x_i, y_i, s_i)}, where s_i is the structured environment label of (x_i, y_i) ∈ E_{s_i}. For instance, in molecular property prediction, s_i is defined as the Murcko scaffold (i.e., a subgraph) of molecule x_i. It is hard to model scaffolds as standard environments because they are structured descriptors and often uniquely identify each molecule in the training set (Figure 1). When an environment has only one molecule, the model cannot decide which subgraph of that molecule is the right scaffold. Thus, creating single-example environments is not helpful for domain generalization.

Alternatively, we can describe scaffold variation by perturbation in the representation φ. The idea is to create a perturbed instance x̃_i for each example (x_i, y_i, s_i) so that the difference between x_i and x̃_i highlights how the scaffold information has changed in the representation. Specifically, the perturbation δ(x_i) is defined through a parametric scaffold classifier g built on top of the representation φ. The associated scaffold classification loss is ℓ(s_i, g(φ(x_i))). Given that our inputs are discrete, we define the perturbation δ as the gradient with respect to the continuous representation φ:

    φ(x̃_i) := φ(x_i) + δ(x_i) = φ(x_i) + α ∇_z ℓ(s_i, g(z))|_{z=φ(x_i)}    (10)

where α is a step size parameter. The perturbation is specifically designed so that the perturbed representation contains less information about the scaffold s_i, and we require that the model not be affected by this variation in the representation. Since these perturbations introduce additional simulated test scenarios that we wish to generalize to, we propose to regularize our model also based on the regret associated with perturbed inputs. Similar to Eq. (3), the regret corresponding to perturbed inputs is defined as R_e(φ + δ):

    R_e(φ + δ) = L_e(f_{−e} • (φ + δ)) − min_h L_e(h • (φ + δ))    (11)
    L_e(h • (φ + δ)) = Σ_{(x_i,y_i)∈E_e} ℓ(y_i, h(φ(x_i) + δ(x_i)))    (12)

This introduces a new oracle predictor f̃_e = arg min_h L_e(h • (φ + δ)) for each environment E_e (see Figure 3a). Note that f_{−e} is the same auxiliary predictor as before: it minimizes a separate objective L_{−e}(f_{−e} • φ), which does not include the perturbed examples. The structured RGM (SRGM) objective L_SRGM augments the basic RGM with additional regret terms as well as the scaffold classification loss L_g(g • φ):

    L_SRGM = L(f • φ) + λ_g L_g(g • φ) + λ Σ_e Σ_{ψ∈{0,δ}} R_e(φ + ψ)    (13)
    L_g(g • φ) = Σ_{(x_i,y_i,s_i)∈D} ℓ(s_i, g(φ(x_i)))    (14)

The forward pass of SRGM is shown in Algorithm 1. Since s is a structured object with a large number of possible values, we train the classifier g with negative sampling (Figure 3b). Note that φ is also updated to partially optimize L_g. This is necessary to ensure that the scaffold classifier operating on φ has enough information to introduce a reasonable gradient perturbation δ(x). This trade-off keeps some scaffold information in φ while ensuring, via the associated regret terms, that this information is not strongly relied upon. The effect of this design choice is studied in our experiments.
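The perturbation of Eq. (10) can be sketched with a linear softmax classifier standing in for g (shapes, data and the step size are illustrative; the paper's encoder is a GCN, not shown here). Because the cross-entropy of a linear softmax is convex in z, a small step along its gradient provably increases the scaffold loss, i.e., the perturbed representation carries less information about s_i.

```python
import numpy as np

rng = np.random.default_rng(0)
d, n_scaffolds = 8, 5
W = rng.normal(scale=0.5, size=(d, n_scaffolds))  # parameters of the stand-in classifier g

def scaffold_loss(z, s):
    # Cross-entropy -log p(s | z) under g(z) = softmax(W^T z).
    logits = z @ W
    logits = logits - logits.max()
    log_p = logits - np.log(np.exp(logits).sum())
    return -log_p[s]

def scaffold_loss_grad(z, s):
    # Gradient of the cross-entropy with respect to z: W @ (p - onehot(s)).
    logits = z @ W
    p = np.exp(logits - logits.max())
    p = p / p.sum()
    return W @ (p - np.eye(n_scaffolds)[s])

def perturb(z, s, alpha=0.1):
    # Eq. (10): phi(x~) = phi(x) + alpha * grad_z loss(s, g(z)) at z = phi(x).
    return z + alpha * scaffold_loss_grad(z, s)

z = rng.normal(size=d)   # stands in for phi(x_i)
s = 2                    # scaffold label s_i
z_tilde = perturb(z, s)  # perturbed representation phi(x~_i)
```

The difference `z_tilde - z` is exactly the domain perturbation δ(x_i), which the regret term R_e(φ + δ) then asks the predictor to be robust to.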

3.3. OPTIMIZATION

The standard RGM objective in Eq. (5) can be viewed as finding a stationary point of a multi-player game between f, φ and the auxiliary predictors {f_{−e}} and {f_e}. Our predictor f and representation φ find their best response strategies by minimizing

    min_{f,φ} L(f • φ) + λ Σ_e [ L_e(f_{−e} • φ) − L_e(f_e • φ) ]

while the auxiliary predictors minimize

    min_{f_{−e}} L_{−e}(f_{−e} • φ)  and  min_{f_e} L_e(f_e • φ)   ∀e

This multi-player game can be optimized by stochastic gradient descent. Since f_e and φ optimize L_e(f_e • φ) in opposite directions, we introduce a gradient reversal layer (Ganin et al., 2016) between φ and f_e. This allows us to update all the players in a single forward-backward pass (see Figure 4). In each step, we simultaneously update all the players with learning rate η:

    f ← f − η ∇_f L(f • φ)
    φ ← φ − η ∇_φ L(f • φ) − ηλ Σ_e ∇_φ R_e(φ)
    f_{−e} ← f_{−e} − η ∇ L_{−e}(f_{−e} • φ)   ∀e
    f_e ← f_e − η ∇ L_e(f_e • φ)   ∀e

where L_{−e}(f_{−e} • φ) = Σ_{k≠e} L_k(f_{−e} • φ). In each step, we sample minibatches B_1, ..., B_n from the environments E_1, ..., E_n. The loss L(f • φ) is computed over all the minibatches ∪_k B_k, while L_{−e}(f_{−e} • φ) is computed over the minibatches B_{−e} = ∪_{k≠e} B_k. The regret term R_e(φ) is evaluated based on examples in B_e only. For structured RGM, the optimization rule is analogous, with additional gradient updates for the oracle predictors f̃_e and the scaffold classifier g (see Appendix A.4). While the perturbation δ is defined on the basis of φ and g, we do not include this dependence during back-propagation, as incorporating the higher-order gradient does not improve our empirical results.
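The gradient reversal trick can be illustrated on a toy scalar problem (all parameters and the loss here are hypothetical, not the paper's model): the reversal layer is the identity on the forward pass and negates the gradient on the backward pass, so a single backward pass through L_e(f_e • φ) makes f_e descend while φ ascends.

```python
def grad_reverse(upstream):
    # Backward rule of the gradient reversal layer: flip the gradient's sign.
    return -upstream

# Toy stand-in for L_e(f_e . phi) with scalar "parameters" w_f and w_phi:
# loss = (w_f * w_phi - 1)^2
w_f, w_phi, eta = 0.2, 0.4, 0.05
residual = w_f * w_phi - 1.0
g_wf = 2 * residual * w_phi                 # gradient for f_e: plain descent
g_wphi = grad_reverse(2 * residual * w_f)   # gradient for phi passes the reversal layer
w_f_new = w_f - eta * g_wf                  # f_e moves to decrease the loss
w_phi_new = w_phi - eta * g_wphi            # phi effectively ascends on the same loss
```

Both players are updated with the same `parameter -= eta * gradient` rule, which is what lets the whole game run in one forward-backward pass.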

4. EXPERIMENTS

Our methods (RGM and SRGM) are evaluated on real-world applications including molecular property prediction and protein homology and stability prediction. Our baselines include:

• Standard empirical risk minimization (ERM) trained on aggregated environments;
• Domain adversarial training methods, including DANN (Ganin et al., 2016) and CDAN (Long et al., 2018), which seek to learn domain-invariant features;
• IRM (Arjovsky et al., 2019), which requires the model to be simultaneously optimal in all environments;
• MLDG (Li et al., 2018a), a meta-learning method which simulates domain shift by dividing training environments into meta-training and meta-testing;
• CrossGrad (Shankar et al., 2018), which augments the training set with domain-guided perturbations of inputs. Since our inputs are discrete, we perform the perturbation on the representation instead.

These methods fall into two categories. SRGM and CrossGrad are structured methods, as they can leverage the structural information of the environment (e.g., scaffold). RGM and the other methods are categorical methods, since they do not utilize this structure and simply treat each environment as a set.

4.1. MOLECULAR PROPERTY PREDICTION

Data

The training data consists of {(x_i, y_i, s_i)}, where x_i is a molecular graph, y_i is its property and s_i is its scaffold. We adopt four datasets from the MoleculeNet benchmark (Wu et al., 2018):

• QM9 is a regression dataset of 134K organic molecules with up to nine heavy atoms. Each molecule is labeled with 12 quantum mechanical properties.
• HIV is a classification dataset of 42K molecules. Each molecule is associated with a binary label indicating whether it is an HIV inhibitor.
• Tox21 is a classification dataset of 8.8K molecules. Each compound has 12 binary labels for toxicity measurements.
• The blood-brain barrier penetration (BBBP) dataset contains 2K molecules. Each molecule is labeled with a binary permeability label.

Data split

To test whether a model generalizes to new domains, it is important to create a test set that is distributionally distinct from the training set. Scaffold split (Wu et al., 2018) is a common framework for this purpose: molecules are clustered based on their Bemis-Murcko scaffold (Bemis & Murcko, 1996) and a random subset of scaffolds is selected into a test set. However, this approach degenerates to random split when most scaffold clusters contain only one molecule (see Figure 1). To address this issue, Feinberg et al. (2019) proposed molecular weight split, where test molecules are much bigger than the molecules in the training set. While this creates a strong structural distinction between the training and test sets, it is not as realistic as the scaffold split.

Given these observations, we propose a variant of scaffold split called scaffold complexity split. We define the complexity of a scaffold as the number of cycles in the scaffold graph. Specifically, we put a scaffold in the test set if its complexity is greater than τ and in the training set if it is less than τ. We set τ = 2 for QM9 and τ = 4 for the other datasets. As shown in Figure 5, this forces the test scaffolds to be structurally different from the training scaffolds. It is also more realistic than the molecular weight split, since the molecular weight distributions of the training and test sets are similar.

Model The molecule encoder φ is a graph convolutional network (Yang et al., 2019) which translates a molecular graph into a continuous vector. The predictor f is a two-layer MLP that takes φ(x) as input and predicts the label. The scaffold classifier g is also a two-layer MLP, trained by negative sampling since the scaffold is a combinatorial object with a large number of possible values. Specifically, for a given molecule x_i with scaffold s_i, we randomly sample K other molecules and take their associated scaffolds {s_k} as negative classes.
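The scaffold complexity criterion above can be sketched in a few lines: the number of independent cycles of an undirected graph is its cyclomatic number |E| − |V| + C. The snippet below is a pure-Python stand-in for an RDKit-based pipeline; the adjacency-list graphs and the helper names are hypothetical, and ties at τ are sent to the training set as one possible convention.

```python
def num_cycles(adj):
    # Cyclomatic number |E| - |V| + C of an undirected graph
    # given as {node: set(neighbors)}.
    nodes = set(adj)
    edges = sum(len(nbrs) for nbrs in adj.values()) // 2  # each edge stored twice
    seen, components = set(), 0
    for start in nodes:
        if start in seen:
            continue
        components += 1
        stack = [start]  # depth-first traversal of one connected component
        while stack:
            u = stack.pop()
            if u in seen:
                continue
            seen.add(u)
            stack.extend(adj[u] - seen)
    return edges - len(nodes) + components

def assign_split(scaffold_graph, tau):
    # Test set if complexity exceeds tau, training set otherwise.
    return "test" if num_cycles(scaffold_graph) > tau else "train"

# A single six-membered ring (benzene-like scaffold): exactly one cycle.
ring6 = {i: {(i - 1) % 6, (i + 1) % 6} for i in range(6)}
```

With τ = 2 (the QM9 setting), a one-ring scaffold lands in the training set, while a scaffold with three or more fused or bridged rings would be routed to the test set.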
Details of the model architecture and hyper-parameters are discussed in the appendix.

Results Following Wu et al. (2018), we report mean absolute error (MAE) for QM9 and AUROC for the other datasets. All results are averaged across five independent runs. Our results on the QM9 dataset are shown in Table 1. RGM outperforms the other categorical methods and demonstrates clear improvement on six properties (mu, alpha, U0, U, H, G). SRGM outperforms all baselines on seven properties, with a significant error reduction on R2, U0, U, H and G (3-10%). Compared to RGM, SRGM performs better on all properties except mu and alpha. On the three classification datasets, SRGM also achieves state-of-the-art performance compared to all the baselines (see Table 2). These results confirm the advantage of exploiting the structure of environments.

4.2. PROTEIN HOMOLOGY PREDICTION

Data The protein homology dataset (Fox et al., 2013; Rao et al., 2019) consists of tuples {(x_i, y_i, s_i)}, where x_i is a protein represented as a sequence of amino acids, y_i its fold label and s_i its superfamily label. The task is to predict the fold label y_i. There are 1195 fold classes and 1823 protein superfamilies in total. Around 1200 superfamilies have fewer than 10 instances in the training set.

Data split Provided by Rao et al. (2019), the dataset consists of 12K instances for training, 736 for validation and 718 for testing. The dataset is split based on protein superfamilies. As a result, proteins in the test set are structurally distinct from those in the training set, requiring models to generalize across large evolutionary gaps.

Model Our protein encoder φ is a pre-trained BERT model (Rao et al., 2019). To generate a sequence-length invariant protein embedding, we simply take the mean of all the vectors output by BERT. The predictor f is a linear function that takes φ(x) as input and predicts its fold class. The superfamily classifier g is a two-layer MLP. The hyperparameters are listed in the appendix.
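The length-invariant embedding step can be sketched as a masked mean-pool over per-residue vectors (shapes and data below are illustrative; the paper's encoder is a pre-trained BERT, which is not reproduced here):

```python
import numpy as np

def mean_pool(hidden, mask):
    # hidden: (batch, seq_len, dim) per-residue vectors from the encoder;
    # mask:   (batch, seq_len), 1 for real residues and 0 for padding.
    mask = mask[..., None].astype(float)
    return (hidden * mask).sum(axis=1) / np.clip(mask.sum(axis=1), 1e-9, None)

batch, seq_len, dim = 2, 4, 3
hidden = np.arange(batch * seq_len * dim, dtype=float).reshape(batch, seq_len, dim)
mask = np.array([[1, 1, 1, 1],
                 [1, 1, 0, 0]])  # second sequence has two padded positions
emb = mean_pool(hidden, mask)    # (batch, dim): one fixed-size vector per protein
```

Masking matters here: without it, padded positions would dilute the embedding of shorter sequences, making φ(x) depend on padding length rather than on the protein alone.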

4.3. ABLATION STUDY OF SRGM

Updating φ for L_g In Section 3.2, we noted that the feature extractor φ is updated to optimize the scaffold (or superfamily) classification loss L_g. To study the effect of this design choice, we evaluate a variant of SRGM called SRGM-detach, in which φ is not updated to optimize the scaffold classification loss. As shown in Table 2 (right), SRGM-detach performs worse than SRGM across the four datasets. This is because the scaffold classifier performs better in SRGM, so the gradient δ(x) reflects the change of scaffold information more accurately.

Random perturbation

In Table 3 , we report the performance of SRGM under random perturbation on the QM9 dataset. Random perturbation performs significantly worse for most of the properties. This shows the importance of the scaffold classifier in SRGM.

Model complexity

To study how the model complexity of φ affects the performance of SRGM, we train SRGM with different numbers of graph convolutional layers on the QM9 dataset. As shown in Table 4, SRGM performs best with three graph convolutional layers, which is the setting adopted in all experiments. In short, SRGM underfits the data when the model is too simple (layer=2) and overfits when the model is too complex (layer=4).

A TECHNICAL DETAILS

A.2 PROOF OF PROPOSITION 2 (CONTINUED) With two environments, the RGM constraint implies

    F_2(φ*) = F_{−1}(φ*) ⊆ F_1(φ*)
    F_1(φ*) = F_{−2}(φ*) ⊆ F_2(φ*)

Therefore F_1(φ*) = F_2(φ*), so ∩_e F_e(φ*) ≠ ∅ and φ* is feasible under IRM. Let f ∈ F_1(φ*) = F_2(φ*). Since f minimizes each L_e(h • φ*), it also minimizes their sum, i.e., min_{h∈F} L(h • φ*) = L(f • φ*). Because φ* ∈ Φ_RGM, this implies

    L(f • φ*) = min_{h∈F} L(h • φ*) = L*_RGM = L*_IRM

Thus f • φ* is optimal under IRM and φ* ∈ Φ_IRM.

A.3 PROOF OF PROPOSITION 3 Given any label-preserving representation φ(x), its ERM optimal predictor is f*(φ(x)) = arg min_f E_{y|φ(x)} ℓ(y, f(φ(x))). To see that f* is ERM optimal, consider

    min_f E_e E_{x|e} E_{y|x,e} ℓ(y, f(φ(x))) ≥ E_e E_{x|e} min_f E_{y|x,e} ℓ(y, f(φ(x)))    (20)
    = E_e E_{x|e} min_f E_{y|φ(x)} ℓ(y, f(φ(x)))    (21)
    = E_e E_{x|e} E_{y|φ(x)} ℓ(y, f*(φ(x)))    (22)

where Eq. (21) holds because φ(x) is label-preserving. Note that f* satisfies the IRM constraint because it is simultaneously optimal across all environments:

    ∀e: min_{f_e} E_{x|e} E_{y|x,e} ℓ(y, f_e(φ(x))) ≥ E_{x|e} min_{f_e} E_{y|x,e} ℓ(y, f_e(φ(x)))    (23)
    = E_{x|e} min_f E_{y|φ(x)} ℓ(y, f(φ(x)))    (24)
    = E_{x|e} E_{y|φ(x)} ℓ(y, f*(φ(x)))    (25)

Moreover, if φ ∈ Φ_IRM is an optimal representation, f* • φ is an optimal solution of IRM.

A.4 SRGM UPDATE RULES In each step, SRGM simultaneously updates all players with learning rate η:

    φ ← φ − η ∇_φ L(f • φ) − η λ_g ∇_φ L_g(g • φ) − η λ Σ_e Σ_{ψ∈{0,δ}} ∇_φ R_e(φ + ψ)
    f ← f − η ∇_f L(f • φ)
    g ← g − η ∇_g L_g(g • φ)
    f_e ← f_e − η ∇ L_e(f_e • φ)   ∀e
    f̃_e ← f̃_e − η ∇ L_e(f̃_e • (φ + δ))   ∀e
    f_{−e} ← f_{−e} − η ∇ L_{−e}(f_{−e} • φ)   ∀e

B EXPERIMENTAL DETAILS

B.1 MOLECULAR PROPERTY PREDICTION

Data The four property prediction datasets are provided in the supplementary material, along with the training/validation/test splits. The sizes of each training environment, validation and test set are listed in Table 5. The QM9, HIV, Tox21 and BBBP datasets are downloaded from Wu et al. (2018).
Model Hyperparameters For the feature extractor φ, we adopt the GCN implementation from Yang et al. (2019). We use their default hyperparameters across all datasets and baselines. Specifically, the GCN contains three convolution layers with hidden dimension 300. The predictor f is a two-layer MLP with hidden dimension 300 and ReLU activation. The model is trained with the Adam optimizer for 30 epochs with batch size 50 and learning rate η linearly annealed from 10^-3 to 10^-4. For RGM, we explore λ ∈ {0.01, 0.1} for each dataset. For SRGM, we explore λ_g ∈ {0.1, 1} for the classification datasets and λ_g ∈ {0.01, 0.1} for the QM9 dataset, as λ_g = 1 causes gradient explosion.

Scaffold Classification

The scaffold classifier is trained by negative sampling since scaffolds are structured objects. Specifically, for each molecule x_i in a minibatch B, the negative samples are the scaffolds {s_k} of the other molecules in the minibatch. The probability that x_i is mapped to its correct scaffold s_i is then defined as

    p(s_i | x_i, B) = exp{g(φ(x_i))ᵀ g(φ(s_i))} / Σ_{k∈B} exp{g(φ(x_i))ᵀ g(φ(s_k))}

The scaffold classification loss is −Σ_i log p(s_i | x_i, B) for a minibatch B. We choose the classifier g to be a two-layer MLP with hidden dimension 300 and ReLU activation.
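The in-batch negative-sampling loss above can be sketched as a softmax over dot products between molecule and scaffold embeddings. The sketch below substitutes random vectors for g(φ(·)) purely for illustration; `scaffold_nll` is a hypothetical helper name, and the batch pairs molecule i with scaffold i on the diagonal.

```python
import numpy as np

rng = np.random.default_rng(0)
batch, dim = 4, 8
mol_emb = rng.normal(size=(batch, dim))   # stands in for g(phi(x_i)), i in B
scaf_emb = rng.normal(size=(batch, dim))  # stands in for g(phi(s_k)), k in B

def scaffold_nll(mol_emb, scaf_emb):
    # -sum_i log p(s_i | x_i, B), with the other in-batch scaffolds as negatives.
    logits = mol_emb @ scaf_emb.T                   # (batch, batch) dot products
    logits = logits - logits.max(axis=1, keepdims=True)  # numerical stability
    log_p = logits - np.log(np.exp(logits).sum(axis=1, keepdims=True))
    return -np.trace(log_p)  # diagonal entries are the correct (x_i, s_i) pairs

loss = scaffold_nll(mol_emb, scaf_emb)
```

Reusing the minibatch's own scaffolds as negatives avoids enumerating the combinatorially large scaffold vocabulary, at the cost of the negative set changing from batch to batch.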

B.2 PROTEIN MODELING

Data The homology and stability datasets are downloaded from Rao et al. (2019). The sizes of each training environment, validation and test set are listed in Table 5.

Model hyperparameters For both tasks, our protein encoder is a pre-trained BERT (Rao et al., 2019). The predictor is a linear layer and the superfamily/topology classifier is a two-layer MLP with hidden dimension 768. The model is fine-tuned with the Adam optimizer with learning rate 10^-4 and a linear warm-up schedule. The batch size is 16 for the homology task and 20 for the stability task. For RGM and SRGM, we explore λ ∈ {0.01, 0.1} and λ_g ∈ {0.1, 1} respectively.

B.3 ADDITIONAL EXPERIMENTS

For the quantum chemistry dataset (QM9), prior work (Chen et al., 2019) has proposed to measure domain generalization via molecular size split. To show that our method also works well under this evaluation setup, we split the dataset based on the number of heavy atoms. The training set contains molecules with no more than 7 heavy atoms. The validation and test sets consist of molecules with 8 and 9 heavy atoms respectively. This setup is much harder than random split as it requires models to extrapolate to new chemical space. The results are shown in Table 6. Among the categorical methods, RGM outperforms all the baselines (except on property mu), with significant improvement on six properties (R2, Cv, U0, U, H, G) with 7-10% relative error reduction. SRGM outperforms all the baselines on eight properties (out of 12). While CrossGrad utilizes scaffold information, its performance is worse than RGM's in general. Compared to RGM, SRGM shows significant error reduction (10-20%) on seven properties (alpha, R2, Cv, U0, U, H, G). This validates the advantage of exploiting the structure of the environments (scaffolds).

We further conduct additional experiments to study the performance of RGM/SRGM with respect to the severity of domain shift. Fixing the test set to molecules with 9 atoms, we construct three progressively harder training sets: molecules with no more than 8, 7 and 6 atoms. We report the MAE ratio (averaged over the 12 properties) between SRGM/RGM/CrossGrad and ERM. As shown in Figure 6, SRGM consistently outperforms CrossGrad and RGM across the different setups.



Our method is introduced using scaffolds as examples. It can be applied to other structured environments, such as protein families, by simply replacing the scaffold classifier with a protein family classifier.

CONCLUSION

In this paper, we propose regret minimization for generalization across structured biomedical domains such as molecular scaffolds or protein families. We seek a representation that enables the predictor to compete against an oracle with hindsight access to unseen domains. Our method significantly outperforms all baselines on real-world biomedical tasks.



Figure 3: Illustration of the SCOP hierarchy modified from Hubbard et al. [39].

Figure 3: a) Structured RGM: we introduce additional oracle predictors f̃_e for the perturbed inputs; b) In molecule tasks, the scaffold classifier g is trained by negative sampling.

Algorithm 1 Structured RGM: Forward Pass
1: for each environment E_e ∈ E do
2:   Sample a minibatch B_e from environment E_e
3:   Compute scaffold classification loss L_g(g • φ) over B_e
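The per-environment sampling loop of the forward pass can be sketched as below. This is a simplified stdlib-only illustration, not the authors' implementation: `phi` stands for the representation φ, and `scaffold_loss` stands for the composed scaffold classification loss L_g(g • φ) on a single example.

```python
import random

def srgm_forward_pass(environments, phi, scaffold_loss, batch_size=16, seed=0):
    # Algorithm 1 sketch: for each environment E_e, sample a minibatch B_e
    # and average the scaffold classification loss over that minibatch.
    rng = random.Random(seed)
    losses = {}
    for name, examples in environments.items():
        batch = rng.sample(examples, min(batch_size, len(examples)))
        losses[name] = sum(scaffold_loss(phi(x)) for x in batch) / len(batch)
    return losses
```

In the full method, these per-environment losses are combined with the regret terms before back-propagation.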

Figure 4: In the RGM forward pass, we sample a minibatch B_e from each environment E_e and compute the regret R_e(φ). In the backward pass, the gradient of L_e(f_e • φ) goes through a gradient reversal layer (Ganin et al., 2016), which negates the gradient during back-propagation.
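The gradient reversal layer acts as the identity on the forward pass and negates the gradient on the backward pass. The sketch below makes both passes explicit with pure-python scalars; real implementations hook into an autograd framework (e.g. a custom autograd function) rather than calling `backward` by hand.

```python
class GradReverse:
    # Gradient reversal layer (Ganin et al., 2016): identity forward,
    # gradient multiplied by -lam on the way back.
    def __init__(self, lam=1.0):
        self.lam = lam

    def forward(self, x):
        # Forward pass: pass the input through unchanged.
        return x

    def backward(self, grad_out):
        # Backward pass: flip the sign of the incoming gradient.
        return -self.lam * grad_out
```

Placed between φ and a predictor, this makes φ ascend the loss that the predictor descends, which is exactly the opposing optimization the figure describes.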

min_{f_{-e}} L_{-e}(f_{-e} • φ) and min_{f_e} L_e(f_e • φ), ∀e

Figure 5: Examples of scaffolds in the QM9 dataset (highlighted in grey). We split the data based on scaffold complexity, so the test scaffolds are structurally distinct from the scaffolds in the training set. As shown in the right panel, the molecular weight distributions of the training, validation, and test sets are similar. This shows that the scaffold complexity split is more realistic than a molecular weight split.

A.1 PROOF OF PROPOSITION 1

Proof. Note that L_e(f • φ) is defined on a set of fixed examples in E_e. Since f_e ∈ arg min_{f∈F} L_e(f • φ) and f_e, f_{-e} are in the same parametric family F, we have R_e(φ) = L_e(f_{-e} • φ) - L_e(f_e • φ) ≥ 0.

A.2 PROOF OF PROPOSITION 2

Proof. Consider any representation φ* ∈ Φ_RGM. When there are only two environments {E_1, E_2}, we have F_{-2}(φ*) = F_1(φ*) and F_{-1}(φ*) = F_2(φ*) by definition. Thus the RGM constraint implies

Since the loss function is non-negative and F is bounded and closed, F_1(φ*) ≠ ∅. Thus ∩_e F_e(φ*) = F_1(φ*) ≠ ∅. Now consider any f ∈ ∩_e F_e(φ*). By definition,

∀e: L_e(f • φ*) ≤ min_{h∈F} L_e(h • φ*)

Summing the above inequality over all environments gives

Σ_e L_e(f • φ*) ≤ Σ_e min_{h∈F} L_e(h • φ*) ≤ min_{h∈F} Σ_e L_e(h • φ*)

Since Σ_e L_e(f • φ*) ≥ min_{h∈F} Σ_e L_e(h • φ*) always holds for f ∈ F, the inequalities above hold with equality.
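Proposition 1 (R_e(φ) ≥ 0) can be checked numerically in a toy setting. The sketch below, an illustration rather than anything from the paper, uses constant predictors under squared error, where the per-environment arg min is simply the mean of the labels the predictor was fit on.

```python
def env_loss(c, ys):
    # Mean squared error of a constant predictor c on environment labels ys.
    return sum((c - y) ** 2 for y in ys) / len(ys)

def regret(ys_e, ys_rest):
    # R_e = L_e(f_{-e}) - L_e(f_e): the oracle f_e is fit on E_e itself,
    # while f_{-e} is fit on the held-out environments' labels ys_rest.
    f_e = sum(ys_e) / len(ys_e)            # arg min over constants = mean
    f_not_e = sum(ys_rest) / len(ys_rest)  # fit on the other environments
    return env_loss(f_not_e, ys_e) - env_loss(f_e, ys_e)
```

Because f_e minimizes L_e within the same family that contains f_{-e}, the regret can never be negative, which is exactly the argument in the proof above.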

Figure 6: QM9 ablation study

Sequence length statistics for the SCOPe datasets. We report the mean and standard deviation along with minimum and maximum sequence lengths.

Comparison of encoder architectures on the ASTRAL 2.06 test set. Encoders included LM inputs and were trained using SSA without contact prediction.


Mean absolute error (MAE) on the QM9 dataset. Models are evaluated under scaffold split. Due to space limits, we only show standard deviations (in subscripts) for the top three methods.

Left: Results on the molecule and protein datasets. CrossGrad and SRGM use the structure of the environments (scaffolds or protein superfamilies) while the other methods do not. Right: Ablation study of SRGM. Detach means we do not update φ to optimize the scaffold (or protein superfamily) classification loss L_g. Acc_S stands for the scaffold/protein superfamily classification accuracy. Property is the property prediction performance (AUROC for molecules, top-1 accuracy for proteins).

Comparison between SRGM with different perturbations on the QM9 dataset. "Scaffold" means perturbation via the gradient of the scaffold classifier. "Random" means random perturbation.

SRGM performance on the QM9 dataset with different numbers of graph convolutional layers in φ. Adding more layers increases model complexity.

Results Following Rao et al. (2019), we report the top-1 accuracy for homology prediction. Our ERM baseline matches their transformer model performance. As shown in Table 2, both RGM and SRGM outperform all the baselines (23.8% vs. 22.3%). The difference between RGM and SRGM is relatively small due to the inaccurate superfamily classifier: its top-1 and top-10 superfamily classification accuracies are around 33.5% and 51.0%, respectively. Nevertheless, SRGM can still give a performance improvement because the gradient perturbation is computed based on the ground-truth superfamily label during training. This teacher-forcing step helps SRGM stay robust to superfamily variability despite the inaccurate superfamily classifier.
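The top-1 and top-10 accuracies quoted above follow the standard top-k definition, sketched below as a generic illustration (not the paper's evaluation code).

```python
def top_k_accuracy(scores, labels, k=1):
    # Fraction of examples whose true class index appears among the k
    # highest-scoring classes; scores is a list of per-class score vectors.
    hits = 0
    for s, y in zip(scores, labels):
        topk = sorted(range(len(s)), key=lambda i: s[i], reverse=True)[:k]
        hits += y in topk
    return hits / len(labels)
```

With k=1 this is standard classification accuracy; raising k (e.g. k=10 for the superfamily classifier) credits a prediction whenever the true class is near the top of the ranking.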

an optimal solution under IRM and φ* ∈ Φ_IRM.

Since f̃_e and φ optimize L(f̃_e • φ, Ẽ_e) in different directions, we also introduce a gradient reversal layer between φ and f̃_e. The SRGM update rule is the following:
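The opposing update can be sketched with scalar parameters and hand-supplied gradients. This is a toy illustration of the sign structure only, under the assumption that both parameters share the loss L(f̃_e • φ, Ẽ_e); it is not the authors' optimizer code.

```python
def srgm_update(phi_param, fe_param, grad_wrt_phi, grad_wrt_fe, lr=0.1):
    # f_e descends the perturbed-environment loss as usual, while the
    # gradient reversal layer flips the sign of the gradient reaching phi,
    # so phi effectively ascends the same loss.
    fe_param = fe_param - lr * grad_wrt_fe
    phi_param = phi_param - lr * (-grad_wrt_phi)
    return phi_param, fe_param
```

With positive gradients, `fe_param` decreases while `phi_param` increases, showing the two parameters moving in opposite directions on a shared loss.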

Mean absolute error (MAE) on the QM9 dataset under molecular size split. Models are trained on molecules with no more than 7 heavy atoms and tested on molecules with 9 heavy atoms. Due to space limits, we only show standard deviations (in subscripts) for the top three methods.

SRGM performance under different molecular size splits.

