CONTINUAL INVARIANT RISK MINIMIZATION

Abstract

Empirical risk minimization can lead to poor generalization behaviour on unseen environments if the learned model does not capture invariant feature representations. Invariant risk minimization (IRM) is a recent proposal for discovering environment-invariant representations. It was introduced by Arjovsky et al. (2019) and extended by Ahuja et al. (2020). The assumption of IRM is that all environments are available to the learning system at the same time. With this work, we generalize the concept of IRM to scenarios where environments are observed sequentially. We show that existing approaches, including those designed for continual learning, fail to identify the invariant features and models across sequentially presented environments. We extend IRM under a variational Bayesian and bilevel framework, creating a general approach to continual invariant risk minimization. We also describe a strategy to solve the optimization problems using a variant of the alternating direction method of multipliers (ADMM). We show empirically, using multiple datasets and multiple sequential environments, that the proposed method outperforms or is competitive with prior approaches.

1. INTRODUCTION

Empirical risk minimization (ERM) is the predominant principle for designing machine learning models. In numerous application domains, however, the test data distribution can differ from the training data distribution. For instance, at test time, the same task might be observed in a different environment. Neural networks trained by minimizing ERM objectives over the training distribution tend to generalize poorly in these situations. Improving the generalization of learning systems has become a major research topic in recent years, with many threads of research including, but not limited to, robust optimization (e.g., Hoffman et al. (2018)) and domain adaptation (e.g., Johansson et al. (2019)). Both of these research directions, however, have their own intrinsic limitations (Ahuja et al. (2020)). Recently, approaches have been proposed that learn environment-invariant representations. The motivating idea is that if the behavior of a model is invariant across environments, it is more likely that the model has captured a causal relationship between features and prediction targets, which in turn should lead to better generalization behavior. Invariant risk minimization (IRM, Arjovsky et al. (2019)), which pioneered this idea, introduces a new optimization loss function to identify non-spurious causal feature-target interactions. Invariant risk minimization games (IRMG, Ahuja et al. (2020)) expands on IRM from a game-theoretic perspective. The assumption of IRM and its extensions, however, is that all environments are available to the learning system at the same time, which is unrealistic in numerous applications: a learning agent often experiences environments sequentially rather than concurrently. For instance, in a federated learning scenario with patient medical records, each hospital's (environment's) data might be used to train a shared machine learning model that receives the data from these environments in a sequential manner.
The model might then be applied to data from an additional hospital (environment) that was not available at training time. Unfortunately, both IRM and IRMG are incompatible with such a continual learning setup, in which the learner receives training data from environments presented in a sequential manner. As already noted by Javed et al. (2020), "IRM Arjovsky et al. (2019) requires sampling data from multiple environments simultaneously for computing a regularization term pertinent to its learning objective, where different environments are defined by intervening on one or more variables of the world." The same applies to IRMG (Ahuja et al. (2020)). To address the problem of learning environment-invariant ML models in sequential environments, we make the following contributions:
• We expand both IRM and IRMG under a Bayesian variational framework and develop novel objectives (for the discovery of invariant models) in two scenarios: (1) the standard multi-environment scenario where the learner receives training data from all environments at the same time; and (2) the scenario where data from each environment arrives in a sequential manner.
• We demonstrate that the resulting bilevel problem objectives have an alternative formulation, which allows us to compute a solution efficiently using the alternating direction method of multipliers (ADMM).
• We compare our method to ERM, IRM, IRMG, and various continual learning methods (EWC, GEM, MER, VCL) on a diverse set of tasks, demonstrating comparable or superior performance in most situations.

2. BACKGROUND: OFFLINE INVARIANT RISK MINIMIZATION

We consider a multi-environment setting where, given a set of training environments E = {e_1, e_2, ..., e_m}, the goal is to find parameters θ that generalize well to unseen (test) environments. Each environment e has an associated training data set D_e and a corresponding risk

R^e(w • φ) := E_{(x,y)∼D_e} ℓ((w • φ)(x), y),

where the model f_θ = w • φ is the composition of a feature extractor φ and a classifier w, and ℓ is a loss function. ERM pools the data of all environments and minimizes

R_ERM(θ) := E_{(x,y)∼∪_{e∈E} D_e} ℓ(f_θ(x), y).

ERM has strong theoretical foundations in the case of i.i.d. data (Vapnik (1992)) but can fail dramatically when test environments differ significantly from training environments. To remove spurious features from the model, Invariant Risk Minimization (IRM, Arjovsky et al. (2019)) instead aims to capture invariant representations φ such that the optimal classifier w given φ is the same across all training environments. This leads to the following bi-level optimization problem

min_{φ∈H_φ, w∈H_w} Σ_{e∈E} R^e(w • φ)   s.t.   w ∈ arg min_{w̄∈H_w} R^e(w̄ • φ), ∀e ∈ E,   (3)

where H_φ and H_w are the hypothesis sets for, respectively, feature extractors and classifiers. Unfortunately, solving the IRM bi-level programming problem directly is difficult, since the outer problem requires solving multiple dependent minimization problems jointly. We can, however, relax IRM to IRMv1 by fixing a scalar classifier and learning a representation φ such that the classifier is "approximately locally optimal" (Arjovsky et al. (2019)):

min_{φ∈H_φ} Σ_{e∈E} R^e(φ) + λ ||∇_{w|w=1.0} R^e(w • φ)||²,   (4)

where w is a scalar classifier fixed at 1.0 and λ controls the strength of the penalty on the gradient with respect to w. Alternatively, the recently proposed Invariant Risk Minimization Games (IRMG, Ahuja et al. (2020)) learns an ensemble of classifiers, with each environment controlling one component of the ensemble. Intuitively, the environments play a game in which each environment's action is to choose its contribution to the ensemble so as to minimize its own risk.
Specifically, IRMG optimizes the following objective:

min_{φ∈H_φ} Σ_{e∈E} R^e(w̄ • φ)   s.t.   w_e = arg min_{w∈H_w} R^e( (1/|E|)(w + w^{-e}) • φ ), ∀e ∈ E,   (5)

where w̄ = (1/|E|) Σ_{e∈E} w_e is the average classifier and w^{-e} = Σ_{e'∈E, e'≠e} w_{e'} is the complement classifier.
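To make the IRMv1 penalty of Equation 4 concrete, the following is a minimal numpy sketch (an illustration, not the paper's implementation), assuming a squared loss and a one-dimensional feature: the penalty is the squared gradient of each environment's risk with respect to the scalar classifier w, evaluated at w = 1.0.

```python
import numpy as np

def irmv1_objective(phi_x, y, envs, lam=1.0):
    """IRMv1-style objective for a squared loss: per-environment risk plus
    the squared gradient of the risk w.r.t. a fixed scalar classifier w = 1.0.

    phi_x : (n,) learned scalar feature phi(x) per sample
    y     : (n,) targets
    envs  : (n,) integer environment index per sample
    """
    total = 0.0
    for e in np.unique(envs):
        z, t = phi_x[envs == e], y[envs == e]
        risk = np.mean((z - t) ** 2)            # R^e(w * phi) at w = 1.0
        grad_w = np.mean(2.0 * (z - t) * z)     # d/dw R^e(w * phi) at w = 1.0
        total += risk + lam * grad_w ** 2       # gradient-norm penalty
    return total
```

A perfectly invariant feature (phi(x) equal to the target in every environment) yields zero risk and zero penalty, while a rescaled, environment-dependent feature is penalized.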

3. CONTINUAL IRM BY APPROXIMATE BAYESIAN INFERENCE

Both IRM and IRMG assume the availability of training data from all environments at the same time, which is impractical and unrealistic in numerous applications. A natural approach would be to combine principles from IRM and continual learning. Experience replay, that is, memorizing examples of past environments and reusing them later, could be possible in some scenarios, but it is often difficult to estimate a priori the extent of replay necessary to achieve satisfactory generalization capabilities. Here, we propose to adopt a probabilistic approach, exploiting the propagation of the model distribution over environments using Bayes' rule. We integrate both IRM and IRMG with stochastic models, introducing their variational counterparts that admit a continual extension. In addition, our approach is justified by the property of the Kullback-Leibler (KL) divergence that promotes invariant distributions when used in sequential learning (as shown in Theorem 3).

3.1. VARIATIONAL CONTINUAL LEARNING

Following prior work in continual learning (Nguyen et al. (2018)), let D_t be the training data from the t-th environment e_t, let D_{1:t} be the cumulative data up to the t-th environment, and let θ be the parameters of the model. When the environments are presented sequentially, Bayes' rule gives (all proofs are provided in the supplementary material)

p(θ | D_{1:t}) ∝ p(θ | D_{1:t-1}) p(D_t | θ),

that is, once we have the posterior distribution p(θ | D_{1:t-1}) at time t−1, we can obtain the posterior p(θ | D_{1:t}) at time t, up to a normalization constant, by multiplying the previous posterior with the likelihood p(D_t | θ) of the current data. The posterior distribution is in general intractable, and we use a variational approximation p(θ | D_{1:t}) ≈ q_t(θ), which makes it possible to propagate the variational distribution from one environment to the next. From Corollary 14 (in the supplementary material) we can write the continual variational Bayesian inference objective as

q_t(θ) = arg min_{q(θ)} E_{(x,y)∼D_t} E_{θ∼q(θ)} [ ℓ(y, f_θ(x)) ] + D_KL(q(θ) || q_{t−1}(θ)),   (7)

where q_{t−1}(θ) is the variational distribution from step t−1 and f_θ = w • φ is a function with parameters θ.
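The Bayesian recursion above can be illustrated exactly in a conjugate setting (an illustration, not the paper's variational method, which only approximates this recursion). The sketch below estimates a scalar Gaussian mean from two environments and checks that sequential posterior updates coincide with the batch posterior computed from all data at once.

```python
import numpy as np

def gaussian_posterior(prior_mu, prior_var, data, noise_var=1.0):
    """Posterior over a scalar theta with likelihood x ~ N(theta, noise_var)
    and prior theta ~ N(prior_mu, prior_var): the conjugate Gaussian update."""
    n = len(data)
    post_var = 1.0 / (1.0 / prior_var + n / noise_var)
    post_mu = post_var * (prior_mu / prior_var + np.sum(data) / noise_var)
    return post_mu, post_var

rng = np.random.default_rng(0)
env1 = rng.normal(2.0, 1.0, 50)   # data from environment 1
env2 = rng.normal(2.0, 1.0, 50)   # data from environment 2

# Sequential: the posterior after env1 becomes the prior for env2,
# mirroring p(theta|D_{1:t}) ∝ p(theta|D_{1:t-1}) p(D_t|theta).
mu1, var1 = gaussian_posterior(0.0, 10.0, env1)
mu_seq, var_seq = gaussian_posterior(mu1, var1, env2)

# Batch: all data at once.
mu_all, var_all = gaussian_posterior(0.0, 10.0, np.concatenate([env1, env2]))
```

In the conjugate case the two routes agree exactly; the variational scheme of Equation 7 approximates this propagation when the posterior is intractable.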

3.2. EQUIVALENT FORMULATION OF IRM AS A BILEVEL OPTIMIZATION PROBLEM (BIRM)

In order to extend the IRM principle of Equation 3 using approximate Bayesian inference, we first introduce, by applying Lemma 5 (in the supplementary material), the following equivalent definition of IRM (Equation 3).

Definition 1 (Bilevel IRM (BIRM)). Let H_φ be a set of feature extractors and let H_w be a set of classifiers. An invariant predictor w • φ on a set of environments E is said to satisfy the Invariant Risk Minimization (IRM) property if it is the solution to the following bi-level Invariant Risk Minimization (BIRM) problem:

min_{φ∈H_φ, w∈H_w} Σ_{e∈E} R^e(w • φ)   (8a)
s.t. ∇_w R^e(w • φ) = 0, ∀e ∈ E.   (8b)

This formulation results from substituting the minimization conditions in the constraint set of the original IRM formulation with the Karush-Kuhn-Tucker (KKT) optimality conditions. The new formulation enables efficient solution methods and simplifies the conditions of IRM. It also justifies the IRMv1 model: when the classifier is a scalar value and the equality constraint is moved into the optimization cost function as a penalty, we obtain Equation 4. To solve the BIRM problem, we propose to use the Alternating Direction Method of Multipliers (ADMM) (Boyd et al. (2011)). ADMM is an alternating optimization procedure that improves convergence and exploits the decomposability of the objective function and constraints. Details of the BIRM-ADMM algorithm are presented in the supplementary material.

3.3. BILEVEL VARIATIONAL IRM

At this point we cannot yet directly extend the IRM principle using variational inference, because when all environments are observed at the same time, the prior for each environment is data independent. We therefore substitute q_{t−1}(θ) in Equation 7 with priors p_φ(θ) and p_w(ω), where θ and ω are now the parameters of the two functions φ and w, and substitute q_t(θ) with the variational distributions q_φ(θ) and q_w(ω).

Definition 2 (Bilevel Variational IRM (BVIRM)). Let P_φ be a family of distributions over feature extractors and let P_w be a family of distributions over classifiers. A variational invariant predictor on a set of environments E is said to satisfy Bilevel Variational Invariant Risk Minimization (BVIRM) if it is the solution to the following problem:

min_{q_φ∈P_φ, q_w∈P_w} Σ_{e∈E} Q_φ^e(q_w, q_φ)   (9a)
s.t. ∇_{q_w} Q_w^e(q_w, q_φ) = 0, ∀e ∈ E,   (9b)

with

Q_φ^e(q_w, q_φ) = E_{w∼q_w, φ∼q_φ} R^e(w • φ) + β D_KL(q_φ || p_φ) + β D_KL(q_w || p_w),   (10a)
Q_w^e(q_w, q_φ) = E_{w∼q_w, φ∼q_φ} R^e(w • φ) + β D_KL(q_w || p_w),   (10b)

where p_φ and p_w are the priors of the two distributions and β is a hyper-parameter balancing the empirical risk against closeness to the prior. Definition 2 extends Definition 1 with the objective of Equation 7, where the parameters φ and w are replaced by their distributions q_φ and q_w. The gradient of the cost in the inner problem is taken with respect to the distribution q_w. When we parameterize q_φ with θ and q_w with ω, the gradient is evaluated with respect to these parameters, since the condition implies that the solution is locally optimal; if Q(p, q) is convex in its first argument, the solution is globally optimal. This definition extends the IRM principle to the case of approximate Bayesian inference, shaping the variational distributions q_w and q_φ to be, in expectation, invariant and optimal across environments.

3.4. THE BVIRM ADMM ALGORITHM

As noted for the BIRM definition, the solution of the BVIRM formulation can be obtained with ADMM (Boyd et al. (2011)). While there are in general no convergence guarantees for ADMM on this problem, for local minima and under proper conditions, the stochastic version of ADMM converges with rate O(1/√t) for convex functions and O(log t / t) for strongly convex functions (Ouyang et al. (2013)). We are now in the position to write the BVIRM-ADMM formulation of the BVIRM problem. ADMM is defined by the updates in Equation 11, where the superscripts − and + denote the value of a variable before and after the update, and we abbreviate Q(ω, θ) = Q(q(ω), q(θ)):

ω_e^+ = arg min_{ω_e} L_ρ(ω_e, u_e^−, ω̄^−, v_e^−), ∀e ∈ E,   (11a)
ω̄^+ = (1/|E|) Σ_e (ω_e^+ + u_e^−),   (11b)
u_e^+ = u_e^− + (ω_e^+ − ω̄^+),   (11c)
v_e^+ = v_e^− + ∇_{q(ω)} Q_w^e(ω_e^+, θ),   (11d)

with

L_ρ(ω_e, u_e, ω̄, v_e) = Q_φ^e(ω_e, θ) + (ρ_0/2) ||ω_e − ω̄ + u_e||² + (ρ_1/2) ||∇_{q(ω)} Q_w^e(ω_e, θ) + v_e||².   (12)

Here φ is fixed and θ is updated in an external loop, or given (e.g., the identity function). In the experiments we use stochastic gradient descent (SGD) to update both the model parameters ω_e and the feature extractor parameters θ. The result follows by applying Lemma 11 in the supplementary material and substituting x_i ← ω_e, f_i(x_i) ← Q_φ^e(ω_e, θ) and g_i(x_i) ← ∇_{q(ω)} Q_w^e(ω_e, θ). We provide a pseudo-code implementation leveraging Equation 11 as Algorithm 1. One advantage of the ADMM formulation of BVIRM in Equation 11 is that it can be computed in parallel: only Equation 11b requires synchronization among environments, while the other steps can be computed independently.
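To illustrate the consensus structure of Equation 11, here is a minimal ADMM sketch with scalar quadratic per-environment risks R^e(w) = 0.5 (w − a_e)²; the gradient-penalty multiplier v_e is omitted, so this shows only the consensus part, and the closed-form local update replaces the inner arg min.

```python
import numpy as np

def consensus_admm(a, rho=1.0, steps=200):
    """Consensus ADMM for min_w sum_e 0.5*(w_e - a_e)^2 s.t. w_e = wbar.
    Each environment updates its own classifier w_e; only the averaging
    step needs synchronization across environments."""
    m = len(a)
    w, u, wbar = np.zeros(m), np.zeros(m), 0.0
    for _ in range(steps):
        # local update: argmin_w 0.5*(w - a_e)^2 + rho/2*(w - wbar + u_e)^2
        w = (a + rho * (wbar - u)) / (1.0 + rho)
        wbar = np.mean(w + u)        # synchronized consensus average
        u = u + w - wbar             # dual (scaled multiplier) update
    return w, wbar
```

For quadratic risks the iterates converge geometrically to the consensus solution, the mean of the a_e; the per-environment steps could run in parallel, matching the remark above.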

3.4.1. THE CONTINUAL BVIRM ADMM ALGORITHM

In the presence of sequential environments, the priors for the new environment are given by the previous environment's distributions q_φ^− and q_w^−; this is obtained by comparing the BVIRM definition in Equation 9 with the continual Bayesian learning objective of Equation 7. In Equation 10 we thus now have

Q_φ^e(q_w, q_φ) = E_{w∼q_w, φ∼q_φ} R^e(w • φ) + β D_KL(q_φ || q_φ^−) + β D_KL(q_w || q_w^−),
Q_w^e(q_w, q_φ) = E_{w∼q_w, φ∼q_φ} R^e(w • φ) + β D_KL(q_w || q_w^−).

Algorithm 2 presents an example implementation of ADMM applied to the continual BVIRM formulation.

3.5. INFORMATION-THEORETIC INTERPRETATION OF C-BVIRM

[Figure 1: Sequential projection of distributions p_1, ..., p_t, where p_{i+1} = arg min_{p∈P_{i+1}} D_KL(p || p_i).]

The KL divergence provides an additional motivation for the methods we propose. Indeed, for causal discovery, Peters et al. (2015) suggest a mechanism that identifies causal variables through the intersection of the conditional distributions that remain invariant across environments subject to interventions. The KL divergence is asymmetric: only components present in the first-argument distribution are evaluated. This implies that by using the KL divergence we can compute the intersection of the distributions even when they are observed sequentially. This can be made more explicit by the property of the information projection (Cover (1999)).

Theorem 3 (Information Projection). If P and Q are two families of distributions with partially overlapping support, ∅ ⊂ supp(P) ∩ supp(Q), and q ∈ Q, then p* = arg min_{p∈P} D_KL(p || q) has support in the intersection of the support of P and that of q, i.e., supp(p*) ⊆ supp(P) ∩ supp(q).

Therefore, if we have a sequence of families of model distributions from intervention environments and we compute the projections in sequence, the final projected distribution has support in the intersection of all previous distribution families, supp(p_t) ⊆ ∩_{i=1}^t supp(P_i) (see Figure 1), since at each step p_{i+1} = arg min_{p∈P_{i+1}} D_KL(p || p_i).
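For discrete distributions, the projection in Theorem 3 has a simple closed form when P is the family of all distributions supported on a set S: minimizing D_KL(p || q) over that family yields q restricted to S, renormalized, so the support shrinks to the intersection at every step. A small numpy sketch of the sequential projection illustrated in Figure 1 (the function name is an illustrative assumption):

```python
import numpy as np

def project(q, support_mask):
    """I-projection of q onto distributions supported on `support_mask`:
    argmin_{supp(p) in mask} KL(p || q) is q restricted to the mask,
    renormalized (any mass outside the mask would make the KL infinite)."""
    p = np.where(support_mask, q, 0.0)
    return p / p.sum()
```

Projecting sequentially onto overlapping supports leaves a distribution supported exactly on the intersection, as the theorem predicts.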

4. RELATED WORK

Generalization. Domain adaptation (Ben-David et al., 2007; Johansson et al., 2019) aims to learn invariant features or components φ(x) that have a similar distribution P(φ(x)) on different (but related) domains by explicitly minimizing a distribution discrepancy measure, such as the Maximum Mean Discrepancy (MMD) (Gretton et al., 2012) or Correlation Alignment (CORAL) (Sun & Saenko, 2016). This condition, however, is not sufficient to guarantee successful generalization to unseen domains, even when the class-conditional distributions of all covariates change between source and target domains (Gong et al., 2016; Zhao et al., 2019). Robust optimization (Hoffman et al., 2018; Lee & Raginsky, 2018), on the other hand, minimizes the worst-case performance over a set of possible environments E, that is, max_{e∈E} R^e(θ). This approach usually poses strong constraints on the closeness between training and test distributions (Bagnell, 2005), which are often violated in practical settings (Arjovsky et al., 2019; Ahuja et al., 2020). Incorporating the machinery of causality into learning models is a recent trend for improving generalization. Bengio et al. (2019) argued that causal models can adapt to sparse distributional changes quickly and proposed a meta-learning objective that optimizes for fast adaptation. IRM, on the other hand, presents an optimization-based formulation to find the non-spurious, actual causal factors of the target y. Extensions of IRM include IRMG and Risk Extrapolation (REx) (Krueger et al., 2020). Our work's motivation is similar to that of online causal learning (Javed et al., 2020), which models the expected value of the target y given each feature as a Markov decision process (MDP) and identifies a feature x_i as spurious if E[y|x_i] is not consistent across temporally distant parts of the MDP. The learning is implemented with a gating model and behaves as a feature selection mechanism; it can therefore be seen as learning the support of the invariant model.
The proposed solution, however, is only applicable to binary features and assumes that the nature of the spurious variables (e.g., the color) is known. It also requires careful hyper-parameter tuning. In cases where the data is not divided into environments, the Environment Inference for Invariant Learning (EIIL) method (Creager et al. (2020)) aims at splitting the samples into environments; the method proves effective even when environment labels are present. Kirkpatrick et al. (2017) and De Lange et al. (2019) address the problem of learning one classifier that performs well across multiple tasks given in a sequential manner, with a focus on avoiding catastrophic forgetting. With our work, we shift the focus of continual learning to the study of a single task that is observed in different environments.

5. EXPERIMENTS

5.1. DATASETS AND EXPERIMENT SETUP

Colored MNIST. Figure 2 (left) shows train (upper) and test (lower) samples. In each training environment, the task is to classify whether the digit is even or odd. As in prior work, we add noise to the preliminary label by randomly flipping it with a probability of 0.25. The color of the image is defined by the variable z, which is the noisy label flipped with probability p_c ∈ [0.1, 0.2]. The digit is colored green if z is even and red if z is odd. Each training environment contains 30,000 images of size 28 × 28 pixels, while the test environment contains 10,000 images with p_c = 0.9. The color of the digit (b01) or of the background (b11) is thus generated from the label but depends on the environment.

[Table 1: Mean accuracy (N = 5) on train and test environments when training on 2 consecutive environments on MNIST and the b01 color correlation.]

Figure 3 shows the Fashion-MNIST dataset, where the variable z defines the background color. Again, we add noise to the preliminary label (y = 0 for "t-shirt", "pullover", "coat", "shirt", "bag" and y = 1 for "trouser", "dress", "sandal", "sneaker", "ankle boots") by flipping it with 25 percent probability to construct the final label. We also consider the Kuzushiji-MNIST dataset (Clanuwat et al. (2018)) and the EMNIST Letters dataset (Cohen et al. (2017)). The former includes 10 Hiragana symbols, whereas the latter contains the 26 letters of the modern English alphabet. For EMNIST, there are 62,400 training samples per environment and 20,300 test samples. We set y = 0 for the letters 'a', 'c', 'e', 'g', 'i', 'k', 'm', 'o', 'q', 's', 'u', 'v', 'y' and y = 1 for the remaining ones.

Reference Methods.
We compare with a set of popular reference methods in order to show the advantage of the variational Bayesian framework for learning invariant models in the sequential environment setup. For completeness, we also evaluate the performance of four reference continual learning methods (EWC, GEM, MER, VCL).
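Returning to the dataset construction: the Colored MNIST labeling scheme described above (noisy parity label, environment-dependent color) can be sketched as follows; the function name and exact sampling details are illustrative assumptions.

```python
import numpy as np

def make_colored_labels(digits, p_c, p_noise=0.25, seed=0):
    """Colored-MNIST-style labels: y is the digit parity, flipped with
    probability p_noise; the color variable z is y flipped with probability
    p_c (green if z is even, red if z is odd)."""
    rng = np.random.default_rng(seed)
    y = digits % 2
    y = np.where(rng.random(len(y)) < p_noise, 1 - y, y)   # label noise
    z = np.where(rng.random(len(y)) < p_c, 1 - y, y)       # env-dependent color variable
    color = np.where(z == 0, "green", "red")
    return y, color
```

With p_c = 0.1 the color agrees with the label about 90% of the time, so a classifier keying on color does well in training; in the test environment (p_c = 0.9) the correlation is inverted and such a classifier fails.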

5.2. RESULTS

Table 1 lists the training and test accuracy on the MNIST dataset with the color correlation b01 (see Figure 2 left). Since we introduced label noise by randomly flipping 25 percent of the given labels, a hypothetical optimal classifier would achieve an accuracy of 75% in both training and test environments. ERM, IRMv1, and IRMG perform poorly in the setup where environments are given sequentially. Similarly, the reference continual learning methods also fail to learn invariant representations in the new environment. As these models learn to rely mainly on spurious features for the classification problems at hand, here the colors of the digits (red∼odd; green∼even), they perform poorly (much worse than a random baseline) when the spurious feature properties are inverted (green∼odd; red∼even). In contrast, our variational extensions to both IRM and IRMG achieve a classification accuracy higher than 45% on the test data, which implies that our models do not rely exclusively on the spurious correlations present in the color of the digits. By comparing the performance of C-VIRMv1 and C-BVIRM, we conclude that (1) our proposed bilevel invariant risk minimization framework (i.e., BIRM in Definition 1) is an effective alternative to the original formulation of Arjovsky et al. (2019); and (2) ADMM is effective in solving the BIRM optimization problem and has the potential to improve generalization performance. In addition, one can observe that the KL divergence term in VCL and in our framework significantly improves the test accuracy with respect to the baseline counterparts. This result further justifies our motivation for using a variational Bayesian framework for the problem of continual invariant risk minimization. Table 2 lists the accuracy on the test environment for: (upper rows) an increasing number n of sequential environments; (central rows) different datasets d; and (lower rows) the two given color correlation schemes c.
We can observe a general trend in the results. IRMG and IRM, with an accuracy of less than 10%, are not able to learn invariant models. Similarly, the continual learning reference methods (EWC, GEM, MER, VCL, VCLC) also fail, with a test accuracy under 25%. The proposed methods, on the other hand, provide mechanisms to learn more robust features and classification models. The higher variance of their accuracy is caused by the stochastic nature of the variational Bayesian formulation.

5.3. ENVIRONMENT INFERENCE FOR CONTINUAL INVARIANT LEARNING

In practical applications, environment labels are usually unavailable, which makes it difficult or impossible to manually partition the training set into "domains" or "environments". In order to generalize our continual invariant learning models to an environment-agnostic setting, we leverage the recently proposed Environment Inference for Invariant Learning (EIIL, Creager et al. (2020)) to automatically infer environment partitions from observational training data, and integrate EIIL into our continual invariant learning models, taking our proposed C-VIRMv1 as an example. According to Table 3, inferring environments directly from observational data (using EIIL) has the potential to improve (continual) invariant learning relative to using hand-crafted environments. Moreover, C-VIRMv1 with EIIL improves both training and test accuracy compared with IRMv1 with EIIL. In fact, this environment partition strategy also enables invariant learning with data from only one environment. Table 4 further suggests that the generalization accuracy improves for both IRMv1 and C-VIRMv1 as the number of training samples increases.

[Table 4: training and test accuracy as the number of training samples increases from 1,000 to 50,000.]

Again, we observe that, when combined with EIIL, C-VIRMv1 always outperforms IRMv1.

A SUPPLEMENTARY MATERIAL

A.1 VARIATIONAL INVARIANT RISK MINIMIZATION GAMES

We now consider the IRMG objective and extend it with variational Bayesian inference. If we observe all environments at the same time, the prior for each environment is data independent. From Equation 7, we thus substitute q_{t−1}(θ) with priors p_φ(θ) and p_w(ω), where θ and ω are now the parameters of the two functions φ and w, and substitute q_t(θ) with the variational distributions q_φ(θ) and q_w(ω). The outer problem is now

min_{q_φ} E_{φ∼q_φ} R^e(w̄ • φ) + β D_KL(q_φ || p_φ)   (13a)
s.t. q_{w_e} = arg min_{q_{w_e}} E_{w_e∼q_{w_e}} R^e( (1/|E|)(w + w^{-e}) • φ ) + β D_KL(q_{w_e} || p_w), ∀e ∈ E_tr,   (13b)

where w̄ = (1/|E|) Σ_{e∈E_tr} w_e, with w_e ∼ q_{w_e}(w), is the average classifier and w^{-e} = Σ_{e'∈E_tr, e'≠e} w_{e'}, with w_{e'} ∼ q_{w_{e'}}(w), is the complement classifier. In this reformulation of the IRMG model, we weight the distance of the variational distribution to the prior with β. The variational formulation of IRMG differs in the expectation of the function over the variational distributions and in the KL terms. We can now extend IRMG to the case where the environments are observed sequentially. Combining the definition of IRMG in Equation 5 with the continual Bayesian learning objective of Equation 7, we obtain the variational objective of IRMG in the sequential environment case:

min_{q_φ} E_{φ∼q_φ} [ ℓ(y, w̄ • φ) ] + β D_KL(q_φ || q_φ^{t−1})   (14a)
s.t. w̄ = (1/2)(w + w^{t−1}), w ∼ q_w(w), w^{t−1} ∼ q_w^{t−1}(w),   (14b)
q_w = arg min_{q_w} E_{w∼q_w, φ∼q_φ} [ ℓ(y, (1/2)(w + w^{t−1}) • φ) ] + β D_KL(q_w || q_w^{t−1}).   (14c)

We can similarly extend the definition of IRMv1 for the cases where all environments are seen at the same time and sequentially.

A.2 MEAN FIELD PARAMETRIZATION AND REPARAMETRIZATION TRICK

When implementing Equation 11, Equation 12, and their variants, we use the mean-field approximation and the reparametrization trick (Kingma & Welling (2013)). In this case the density of our model is parameterized by θ and ω, and the constraint becomes

∇_{q(ω)} Q_w^e(ω, θ) = 0 → ∇_ω Q_w^e(ω, θ) = 0.

We parametrize the mean and standard deviation as μ(ω_μ) and σ(ω_σ) and model the distribution as

w = μ(ω_μ) + ε σ(ω_σ), with ε ∼ N(0, 1), w ∼ q_ω(w).

We now want to compute the gradient (in the following we ignore the dependence on φ and its parameters)

∇_ω Q(ω) = ∇_ω E_{w∼q(ω)} R(w • φ) + β ∇_ω D_KL(q(ω) || p).

The second term is

∇_ω D_KL(q || p) = ∇_μ D_KL(q || p) ∇_ω μ + ∇_σ D_KL(q || p) ∇_ω σ,

with ∇_ω μ = 1, ∇_ω σ = 1 and

∇_μ D_KL(q || p) = −σ_p^{−1}(μ_p − μ_q),
∇_σ D_KL(q || p) = −diag(σ_q)^{−1} + diag(σ_p)^{−1},

where we assume σ_p, σ_q to be diagonal, so the previous expressions can be evaluated elementwise, and where D_KL(q || p) is defined as

D_KL(q || p) = ln(|Σ_p| / |Σ_q|) − n + tr(Σ_p^{−1} Σ_q) + (μ_p − μ_q)^T Σ_p^{−1} (μ_p − μ_q).

The first term is evaluated by Monte Carlo sampling,

∇_ω E_{w∼q(ω)} R(w) ≈ ∇_ω (1/N) Σ_{i=1}^N R(w_i), with w_i = μ(ω) + ε_i σ(ω) and ε_i ∼ N(0, 1).

Also in this case,

∇_ω (1/N) Σ_i R(w_i) = ∇_μ (1/N) Σ_i R(w_i) ∇_ω μ + ∇_σ (1/N) Σ_i R(w_i) ∇_ω σ.

A.3 THE BIRM-ADMM ALGORITHM

To solve BIRM we can use Lemma 11 and write the following algorithm:

w_e^+ = arg min_{w_e} L_ρ(w_e, u_e^−, w̄^−, v_e^−), ∀e ∈ E,   (15a)
w̄^+ = (1/|E|) Σ_e (w_e^+ + u_e^−),   (15b)
u_e^+ = u_e^− + (w_e^+ − w̄^+),   (15c)
v_e^+ = v_e^− + ∇_w R^e(w_e^+ • φ),   (15d)

where

L_ρ(w_e, u_e, w̄, v_e) = R^e(w_e • φ) + (ρ_0/2) ||w_e − w̄ + u_e||² + (ρ_1/2) ||∇_w R^e(w_e • φ) + v_e||².

We denote by ·^+ and ·^− the values of a variable after and before the update. To implement the method, we use SGD to update the model w_e, with an outer loop updating φ.
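The mean-field quantities above can be sketched in numpy (an illustration, not the paper's implementation): the KL divergence between diagonal Gaussians, written here with the standard 1/2 factor, and the reparameterized Monte Carlo estimate of the gradient of the expected risk with respect to the mean; function names are illustrative assumptions.

```python
import numpy as np

def kl_diag_gaussians(mu_q, sig_q, mu_p, sig_p):
    """KL(q||p) between diagonal Gaussians (standard closed form with the
    1/2 factor), evaluated elementwise and summed."""
    return np.sum(np.log(sig_p / sig_q)
                  + (sig_q**2 + (mu_p - mu_q)**2) / (2.0 * sig_p**2) - 0.5)

def reparam_grad_mu(mu, sigma, risk_grad, n=100000, seed=0):
    """Monte Carlo estimate of d/dmu E_{w~N(mu,sigma^2)} R(w) via the
    reparametrization w = mu + eps*sigma; the chain rule gives dw/dmu = 1,
    so the gradient is the sample average of R'(w)."""
    eps = np.random.default_rng(seed).standard_normal(n)
    w = mu + eps * sigma
    return np.mean(risk_grad(w))
```

For instance, with R(w) = w² the exact gradient with respect to μ is 2μ, which the sampler recovers up to Monte Carlo error.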

A.4 VARIATIONAL INVARIANT RISK MINIMIZATION

Definition 4 (VIRM). Given a family of distributions P_φ over feature extractors and a family of distributions P_w over classifiers, a variational invariant predictor on a set of environments E is said to satisfy Variational Invariant Risk Minimization (VIRM) if it is the solution of the following problem:

min_{q_φ∈P_φ, q_w∈P_w} Σ_{e∈E} Q_φ^e(q_w, q_φ)   (17a)
s.t. q_w ∈ arg min_{q_w^e∈P_w} Q_w^e(q_w^e, q_φ), ∀e ∈ E,   (17b)

where Q_φ^e(q_w, q_φ) = E_{w∼q_w, φ∼q_φ} R^e(w • φ) + β D_KL(q_φ || p_φ) + β D_KL(q_w || p_w), Q_w^e(q_w, q_φ) = E_{w∼q_w, φ∼q_φ} R^e(w • φ) + β D_KL(q_w || p_w), and p_φ, p_w are the priors of the two distributions.

A.5 BILEVEL ALTERNATIVE FORMULATION

We state here a general result on solving bilevel optimization problems.

Lemma 5 (Bilevel Reformulation). Given the bilevel problem

min_x F(x, y(x))  s.t.  G(x, y(x)) ≤ 0,   (18a)
y(x) ∈ arg min_y { f(x, y) | g(x, y) ≤ 0 },   (18b)

we can solve the equivalent problem

min_{x,y,u} F(x, y)  s.t.  G(x, y) ≤ 0,   (19a)
∇_y L(x, y, u) = 0,   (19b)
u ≥ 0,   (19c)
g(x, y) ≤ 0,   (19d)
u^T g(x, y) = 0,   (19e)

where L(x, y, u) = f(x, y) + u^T g(x, y).

Proof of Lemma 5. Lemma 5 follows by applying the Karush-Kuhn-Tucker conditions (Chapter 5, Boyd et al. (2004)) to Equation 18, where the Lagrangian function is L(x, y, u) = f(x, y) + u^T g(x, y).

Lemma 6 (Equivalence of Definition 1). Definition 1 is equivalent to Equation 3, the Invariant Risk Minimization problem.

Proof of Lemma 6. The result follows by applying Lemma 5 to Equation 3.

Lemma 7 (Definition 2). Definition 2 is the extension of Equation 8, the Bilevel Invariant Risk Minimization problem, when the functions are described by the distributions of their variables φ and w.

Proof of Lemma 7. The result follows by inspecting Equation 8. The equation requires the minimization of the aggregated loss function, which is now, from Equation 7:

Q_φ^e(q_w, q_φ) = E_{w∼q_w, φ∼q_φ} R^e(w • φ) + β D_KL(q_φ || p_φ) + β D_KL(q_w || p_w),   (20)

where we have separated the two contributions in φ and w, and used generic prior distributions p_φ and p_w. This separation holds by the additive property of the KL divergence, D_KL(q_φ q_w || p_φ p_w) = D_KL(q_φ || p_φ) + D_KL(q_w || p_w), since we model the two distributions independently, i.e., q_{φ,w} = q_φ q_w and p_{φ,w} = p_φ p_w. Since the classifiers' losses shall be minimal for all environments, this condition is substituted by requiring the gradient with respect to q_w to be zero, ∀e. The gradient w.r.t. q_w of the second term of Equation 20 is zero.

We now show two ways to state the connection between the IRM principle and the sequential information projection. Let q^− be the distribution of the previous environment and R(q) the loss function of the current environment, where q denotes the distribution of the network parameters.
Let q* = arg min_q R(q) be the optimal distribution for the current environment. Consider the Taylor expansion of the loss around the optimal distribution,

  R(q* + ∆q) ≈ R(q*) + ∆q^T ∇_q R(q*).

We can then compute the new distribution as

  ∆q* = arg min_{∆q} D_KL(q* + ∆q || q^-)  s.t.  ∆q^T ∇_q R(q*) ≤ ε,

and set q^+ = q* + ∆q*. Alternatively,

  p*, q* = arg min_{p,q} D_KL(p || q^-) + D_KL(p || q)  s.t.  ∇_q R(q) = 0,

and then q^+ = p*. Or, more simply,

  q^+ = arg min_p D_KL(p || q^-) + D_KL(p || q*).

This last equation shows that the new distribution is the intersection of the optimal distribution at the previous step, q^-, and the current optimal distribution q*. Figure 8 shows visually how the new distribution results from projecting onto the two distributions q* and q^-.
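For discrete distributions the last projection has a closed form: the minimizer of D_KL(p||q^-) + D_KL(p||q*) over the simplex is the normalized geometric mean of the two distributions, which makes the support-intersection behavior explicit (a small numpy sketch; the two distributions are illustrative):

```python
import numpy as np

def i_projection_pair(q_prev, q_star):
    """Minimizer of D_KL(p||q_prev) + D_KL(p||q_star) over the simplex.
    Setting the Lagrangian gradient to zero gives p ∝ sqrt(q_prev * q_star),
    i.e. the normalized geometric mean of the two distributions."""
    p = np.sqrt(q_prev * q_star)
    return p / p.sum()

q_prev = np.array([0.5, 0.5, 0.0, 0.0])   # previous-environment distribution q^-
q_star = np.array([0.0, 0.5, 0.5, 0.0])   # current optimal distribution q*
q_plus = i_projection_pair(q_prev, q_star)
# The geometric mean vanishes wherever either factor vanishes, so the
# support of q_plus is exactly the intersection of the two supports.
```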

A.7 OUT OF DISTRIBUTION GENERALIZATION

The question arises whether the out-of-distribution generalization property given by Theorem 9 in Arjovsky et al. (2019) also holds for BIRM and BVIRM.

Lemma 9. If φ and w are linear functions and w • φ = Φ^T w is a solution of Eq. 8, then it satisfies

  Φ E_{X^e}[X^e X^{eT}] Φ^T w = Φ E_{X^e,Y^e}[X^e Y^{eT}].

Proof. Lemma 9 follows from the fact that

  ∇_w R^e(w • φ) = Φ E_{X^e}[X^e X^{eT}] Φ^T w - Φ E_{X^e,Y^e}[X^e Y^{eT}] = 0.   (23)

For BIRM, Theorem 9 of Arjovsky et al. (2019) thus applies directly. A similar result holds for the BVIRM model.

Lemma 10. If φ ∼ p_φ and w ∼ p_w are linear functions and w • φ = Φ^T w is a solution of Eq. 9 with β = 0, then it satisfies

  E_{w∼q_w, φ∼q_φ}[ Φ E_{X^e}[X^e X^{eT}] Φ^T w ] = Φ̄ E_{X^e,Y^e}[X^e Y^{eT}],

where Φ̄ = E_{Φ∼p_φ}[Φ] and w̄ = E_{w∼p_w}[w] are the mean values.

Proof. Lemma 10 follows from the fact that

  ∇_{q_w} Q^e_w(q_w, q_φ)|_{β=0} = ∇_{q_w} E_{w∼q_w, φ∼q_φ}[R^e(w • φ)] = 0.   (25)

We take the Fréchet directional derivative in the direction η, that is, the limit

  δ_{q_w,η} E_{w∼q_w, φ∼q_φ}[R^e(w • φ)] = lim_{ε→0} (1/ε) ( E_{w∼q_w, φ∼q_φ}[R^e((w + εη) • φ)] - E_{w∼q_w, φ∼q_φ}[R^e(w • φ)] ),

which is obtained by perturbing w → w + εη. Since

  δ_{q_w,η} E_{w∼q_w, φ∼q_φ}[R^e(w • φ)] = E_{w∼q_w, φ∼q_φ}[ 2η^T Φ E_{X^e}[X^e X^{eT}] Φ^T w - 2η^T Φ E_{X^e,Y^e}[X^e Y^{eT}] ],

we can factor out the direction η and obtain

  δ_{q_w} R^e(w • φ) = 2 E_{w∼q_w, φ∼q_φ}[ Φ E_{X^e}[X^e X^{eT}] Φ^T w - Φ E_{X^e,Y^e}[X^e Y^{eT}] ].

We can now derive the lemma by requiring δ_{q_w} R^e(w • φ) = 0. Theorem 9 of Arjovsky et al. (2019) now holds when Φ has rank r > 0 in expectation with respect to the invariant distribution q_φ, i.e. E_{φ∼q_φ}[rank(Φ)] = r.
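The stationarity condition in Lemma 9 is the normal equation of least squares in the featurized space, and it can be verified numerically (a minimal numpy sketch with an arbitrary random linear featurizer; the data-generating model here is illustrative):

```python
import numpy as np

rng = np.random.default_rng(0)
n, d, r = 5000, 3, 2
X = rng.normal(size=(n, d))
w_data = np.array([1.0, -2.0, 0.5])
Y = X @ w_data + 0.1 * rng.normal(size=n)

Phi = rng.normal(size=(r, d))          # linear featurizer phi(x) = Phi x
Exx = (X.T @ X) / n                    # empirical E[X X^T]
Exy = (X.T @ Y) / n                    # empirical E[X Y]

# Solve the stationarity condition  Phi E[X X^T] Phi^T w = Phi E[X Y]
w = np.linalg.solve(Phi @ Exx @ Phi.T, Phi @ Exy)
# The gradient of the squared risk in w then vanishes, as in Eq. 23.
grad = Phi @ Exx @ Phi.T @ w - Phi @ Exy
```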

A.8 GENERALIZED ADMM

The following generalization of ADMM holds.

Lemma 11 (GADMM). Suppose we want to minimize

  min_x Σ_i f_i(x)  s.t.  g_i(x) = 0, ∀i ∈ I.   (27)

We can equivalently solve the problem

  min_{x_i, z} Σ_i f_i(x_i)  s.t.  x_i = z, g_i(x_i) = 0, ∀i ∈ I,   (28)

using scaled ADMM updates, one per term f_i.

The variational continual learning distribution of Lemma 12 is obtained by minimizing the KL divergence between the variational distribution q_t and the distribution induced by the previous-step approximation,

  q_t(θ) = arg min_{q(θ)} D_KL( q(θ) || (1/Z_t) q_{t-1}(θ) p(D_t|θ) ).

Lemma 13 (VCLv2). The minimization of the VCL objective defined in Lemma 12 is equivalent to solving

  q_t(θ) = arg max_{q(θ)} E_{θ∼q(θ)}[log p(D_t|θ)] - D_KL(q(θ) || q_{t-1}(θ)),

with, for N_t i.i.d. samples,

  E_{θ∼q(θ)}[log p(D_t|θ)] = (1/N_t) Σ_{i=1}^{N_t} E_{θ∼q(θ)}[log p(y^t_i | θ, x^t_i)].

The second term can be computed in closed form for known distributions, for example Gaussians, whereas the expectation can be approximated by Monte Carlo sampling. For a general loss function ℓ we can substitute the log-likelihood with the loss function associated with a neural network parameterized by θ,

  log p(y^t_i | θ, x^t_i) ← -ℓ(y^t_i, (w • φ)_θ(x^t_i)),
  E_{θ∼q(θ)}[log p(D_t|θ)] ← -(1/N_t) Σ_{i=1}^{N_t} E_{θ∼q(θ)}[ℓ(y^t_i, (w • φ)_θ(x^t_i))].

Proof of Lemma 13. The lemma follows from the definition of the KL divergence:

  D_KL( q(θ) || (1/Z_t) q_{t-1}(θ) p(D_t|θ) )
    = E_q[ ln q(θ) - ln q_{t-1}(θ) - ln p(D_t|θ) + ln Z_t ]
    = E_q[ ln q(θ) - ln q_{t-1}(θ) ] - E_q[ ln p(D_t|θ) ] + ln Z_t
    = D_KL(q(θ) || q_{t-1}(θ)) - E_q[ ln p(D_t|θ) ] + ln Z_t.

The last term does not depend on q; thus the result follows. If we substitute the log of the posterior probability with a specific loss function we obtain the following corollary.

Corollary 14 (Continual Variational Bayesian Inference).
Given a loss function ℓ(y, ŷ), variational continual learning is formulated as

  q_t(θ) = arg min_{q(θ)} E_{(x,y)∼D_t} E_{θ∼q(θ)}[ℓ(y, f_θ(x))] + D_KL(q(θ) || q_{t-1}(θ)),  with f_θ = (w • φ)_θ.

A.10 PROOFS

Proof of Theorem 3. First recall that D_KL(p||q) = ∫ p(x) ln(p(x)/q(x)) dx. If q(x) = 0 then we must have p(x) = 0, otherwise the divergence is infinite. Second, if p(x) = 0, the contribution of q(x) is not considered, since the integral is taken over the support of p. Thus, since the intersection is not empty and p is the result of the optimization, the support of p is the intersection of the support of q and the support of P.
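The two observations in the proof — that D_KL(p||q) is infinite unless supp(p) ⊆ supp(q), and that points outside supp(p) contribute nothing — can be checked numerically for discrete distributions (a small numpy sketch with illustrative distributions):

```python
import numpy as np

def kl_discrete(p, q):
    """D_KL(p || q) for discrete distributions: infinite when p puts mass
    where q has none; terms with p(x) = 0 contribute nothing."""
    p, q = np.asarray(p, float), np.asarray(q, float)
    if np.any((q == 0) & (p > 0)):
        return np.inf
    mask = p > 0
    return float(np.sum(p[mask] * np.log(p[mask] / q[mask])))

# p puts mass where q does not: the divergence is infinite.
kl_inf = kl_discrete([0.5, 0.5, 0.0], [1.0, 0.0, 0.0])
# q's mass outside supp(p) is simply ignored: only the first bin contributes.
kl_fin = kl_discrete([1.0, 0.0, 0.0], [0.5, 0.25, 0.25])
```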

A.11 DATASETS AND COLOR CORRECTION

We visualize here a few of the datasets and color correlations. Figure 5 shows Fashion-MNIST and the b11 color correlation; in the test environment the background color of each class is inverted. In Figure 6 we show the datasets as generated in Ahuja et al. (2020). In Figure 7 we show the EMNIST (letters) and KMNIST datasets.

A.12 HYPER-PARAMETER SEARCH AND EXPERIMENTAL SETUP

We performed a hyper-parameter search around the values suggested in the original works; the values were selected based on the best performance on the test environment. To implement a complete comparison, we used 1 000 samples randomly drawn from each environment for training. All methods were trained on the same data, using random seed resets. We trained all methods for 100 epochs with a batch size of 256.

• IRM: γ = 91257, penalty threshold = 1/2 epochs, learning rate 2.5e-4
• IRMG: warm start = 300, termination accuracy 0.6, learning rate 2.5e-4, dropout probability 75%, weight decay = 0.00125
• ERM: learning rate 1e-3, dropout probability 75%, weight decay = 0.00125
• MER: memory size 100 (10% of the samples), learning rate 1e-3, replay batch size = 5, β = 0.03, γ = 1.0
• GEM: memory size 100 (10% of the samples), learning rate 1e-3
• EWC: memory size 100 (10% of the samples), learning rate 1e-3, regularization 0.1
• VCL, VCLC: learning rate 5e-3, coreset size 100 (10% of the samples)
• C-BVIRM, C-VIRMG, C-VIRMv1: weight decay = 0.00125, β = 1.0, number of evaluations 5, ρ_0 = ρ_1 = 10, step threshold = 1/2 epochs, δρ = 100, learning rate 1e-3

The neural network architecture is composed of two fully connected layers of size 100 with Exponential Linear Unit (ELU) activations, followed by a linear fully connected layer.

A.13.1 SYNTHETIC DATASET

The synthetic dataset is described in Arjovsky et al. (2019) for testing IRM. It is defined by a structural causal model (Pearl (2009)), where a variable y ∈ R^N is generated by x_1 ∈ R^N, while x_2 ∈ R^N is generated by y.
The observed variable is x = (x_1, x_2). The structural equations are

  x_1 = ε_1,  ε_1 ∼ N(0, σ_1²),   (32)
  y = x_1 + ε_y,  ε_y ∼ N(0, σ_e²),   (33)
  x_2 = y + ε_2,  ε_2 ∼ N(0, 1),

with σ_1 fixed and σ_e dependent on the environment. We compared against ERM, IRM (Arjovsky et al. (2019)), ICP (invariant causal prediction), which is the method proposed in Peters et al. (2015), and EIIL (Creager et al. (2020)). We use a setup similar to that of Creager et al. (2020), with N = 4. The invariant model is given by w = (1, 0).
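The data-generating process above can be sketched directly: regressing y on both x_1 and x_2 picks up an environment-dependent, spurious coefficient on x_2, while regressing on the causal parent x_1 alone recovers the invariant coefficient 1 in every environment (a numpy sketch under the stated SCM, with scalar variables for simplicity):

```python
import numpy as np

def sample_environment(sigma_e, n=20000, sigma_1=1.0, seed=0):
    """Synthetic SCM of the experiments: x1 -> y -> x2, where only the
    noise scale sigma_e on y varies across environments."""
    rng = np.random.default_rng(seed)
    x1 = sigma_1 * rng.normal(size=n)
    y = x1 + sigma_e * rng.normal(size=n)
    x2 = y + rng.normal(size=n)
    return np.stack([x1, x2], axis=1), y

def ols(X, y):
    """Ordinary least squares, what plain ERM would fit on pooled features."""
    return np.linalg.lstsq(X, y, rcond=None)[0]

X_a, y_a = sample_environment(sigma_e=0.5, seed=1)
X_b, y_b = sample_environment(sigma_e=2.0, seed=2)
w_a, w_b = ols(X_a, y_a), ols(X_b, y_b)
# Population OLS weight on the spurious x2 is sigma_e^2 / (1 + sigma_e^2),
# so it changes with the environment, while the invariant model is w = (1, 0).
```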



Implementation details using the mean-field parameterization and the reparametrization trick are provided in the Supplementary Material. These conditions are specific bounds on the magnitude and variance of the (sub-)gradients of the stochastic function (Ouyang et al. (2013)). We used ELU ∈ C^∞ activations in the experiments. In Algorithm 1 the ADMM update equations are implemented in lines 6 to 9, and in Algorithm 2 in lines 7 to 10.

https://github.com/rois-codh/kmnist
https://www.nist.gov/itl/products-and-services/emnist-dataset
https://github.com/facebookresearch/GradientEpisodicMemory
https://github.com/mattriemer/mer
https://github.com/nvcuong/variational-continual-learning
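The mean-field parameterization with the reparametrization trick mentioned above can be sketched as follows (a minimal numpy illustration; the softplus scale parameterization is a common choice and an assumption here, not a detail taken from the paper):

```python
import numpy as np

def reparam_sample(mu, rho, rng):
    """Mean-field Gaussian sample via the reparametrization trick:
    theta = mu + softplus(rho) * eps with eps ~ N(0, I), so that gradients
    with respect to the variational parameters (mu, rho) flow through theta."""
    sigma = np.log1p(np.exp(rho))   # softplus keeps the scale positive
    eps = rng.normal(size=np.shape(mu))
    return mu + sigma * eps

rng = np.random.default_rng(0)
samples = reparam_sample(np.zeros(100000), np.zeros(100000), rng)
# Empirically: mean ~ 0 and std ~ softplus(0) = log(2).
```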



Algorithm 1: w, φ ← BVIRM-ADMM(E, R^e), the ADMM version of the Bilevel Variational IRM algorithm.
Result: w • φ: feature extractor and classifier for the environments E
   // Randomly initialize the variables
1  ω, ω_e, u_e, v_e, θ ← Init() ;
   // Outer loop (on θ) and inner loop (on ω)
2  while not converged do
3    θ = SGD_θ( Σ_{e∈E} Q^e_φ(q_w, q_φ) ) ;  // Update φ using SGD
4    for k = 1, …, K do
5      for e ∈ E do
6        ω_e = SGD_{ω_e}( L_ρ(ω_e, u_e, ω, v_e) ) ;
7        ω = (1/|E|) Σ_e (ω_e + u_e) ;
8        u_e = u_e + (ω_e - ω) ;
9        v_e = v_e + ∇_ω Q^e(ω_e, θ) ;
       end
     end
   end

Algorithm 2: w, φ ← C-BVIRM-ADMM(E, R^e), the ADMM version of the Continual Bilevel Variational IRM algorithm.
Result: w_ω • φ_θ: feature extractor and classifier for the environments E
   // Randomly initialize the variables
1  ω, ω_e, u_e, v_e, θ ← Init() ;
2  ω = 0 ;
3  for e ∈ E do
4    for k = 1, …, K do
5      θ = SGD_θ( Q^e_φ(q_w, q_φ) ) ;
6      while not converged do
         // Update ω using SGD and ADMM
7        ω_e = SGD_{ω_e}( L_ρ(ω_e, u_e, ω, v_e) ) ;
8        ω = (1/2)(ω_e + u_e + ω) ;
9        u_e = u_e + (ω_e - ω) ;
10       v_e = v_e + ∇_ω Q^e(ω_e, θ) ;
       end
     end
   end
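Lines 6-9 of Algorithm 1 are a scaled-ADMM consensus update: a local step per environment, averaging of the local variables plus duals, and a dual update. The pattern can be sketched on a consensus least-squares problem, replacing the SGD local step with its closed-form solution (a minimal numpy sketch of the generic pattern, not the paper's implementation):

```python
import numpy as np

def consensus_admm(As, bs, rho=1.0, n_iter=200):
    """Scaled-ADMM consensus mirroring lines 6-9 of Algorithm 1:
    each environment e solves a local subproblem regularized toward the
    shared variable, the shared variable is the average of (w_e + u_e),
    and the scaled duals u_e accumulate the disagreement."""
    d = As[0].shape[1]
    w = np.zeros(d)                       # shared (consensus) variable
    ws = [np.zeros(d) for _ in As]        # per-environment variables
    us = [np.zeros(d) for _ in As]        # scaled dual variables
    for _ in range(n_iter):
        for e, (A, b) in enumerate(zip(As, bs)):
            # argmin_v ||A v - b||^2 + (rho/2)||v - (w - u_e)||^2, in closed form
            ws[e] = np.linalg.solve(2 * A.T @ A + rho * np.eye(d),
                                    2 * A.T @ b + rho * (w - us[e]))
        w = np.mean([we + ue for we, ue in zip(ws, us)], axis=0)
        for e in range(len(As)):
            us[e] = us[e] + ws[e] - w
    return w

rng = np.random.default_rng(0)
w_true = np.array([1.0, -2.0, 0.5])
As = [rng.normal(size=(50, 3)) for _ in range(3)]
bs = [A @ w_true for A in As]
w_hat = consensus_admm(As, bs)
```

With consistent per-environment data, the consensus iterate converges to the shared minimizer, which is the role the shared classifier ω plays across environments in the algorithm.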

Figure 2: The two color models (on the left b01, on the right b11) for the train (upper row) and test (lower row) of the MNIST (left) and FashionMNIST (right) datasets.

Figure 3: Causal relationships of colored MNIST.

Colored Fashion-MNIST, KMNIST, and EMNIST. Figure 2 (right) shows the Fashion-MNIST dataset, where the variable z defines the background color. Again, we add noise to the preliminary label (y = 0 for "t-shirt", "pullover", "coat", "shirt", "bag" and y = 1 for "trouser", "dress", "sandal", "sneaker", "ankle boot") by flipping it with 25 percent probability to construct the final label. We also consider the Kuzushiji-MNIST dataset (Clanuwat et al. (2018)) and the EMNIST Letters dataset (Cohen et al. (2017)). The former includes 10 symbols of Hiragana, whereas the latter contains 26 letters of the modern English alphabet. For EMNIST, there are 62,400 training samples per environment and 20,300 test samples. We set y = 0 for the letters 'a', 'c', 'e', 'g', 'i', 'k', 'm', 'o', 'q', 's', 'u', 'v', 'y' and y = 1 for the remaining ones.
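The label construction described above — binarize the classes, flip the label with 25% probability, then choose the background color with an environment-dependent label-color correlation — can be sketched as follows (a numpy sketch; the MNIST-style split at class ≥ 5 and the generic correlation parameter are assumptions for illustration, the exact per-dataset groupings are as given in the text):

```python
import numpy as np

def make_colored_labels(class_ids, flip_prob=0.25, color_corr=0.1, rng=None):
    """Colored-dataset construction: binarize the class into a preliminary
    label, flip it with probability flip_prob to get the final label, then
    pick the background color to disagree with the label with probability
    color_corr (the environment-dependent correlation)."""
    rng = rng or np.random.default_rng(0)
    y = (np.asarray(class_ids) >= 5).astype(int)              # preliminary label
    y = np.where(rng.random(y.shape) < flip_prob, 1 - y, y)   # 25% label noise
    color = np.where(rng.random(y.shape) < color_corr, 1 - y, y)
    return y, color
```

Because the color agrees with the label more reliably than the noisy image class does during training, an ERM model is tempted to rely on color, which is exactly the spurious correlation that is inverted in the test environment.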

w, φ ← BIRM-ADMM(E, R^e), the ADMM version of the Bilevel IRM algorithm.
Result: w • φ: feature extractor and classifier for the environments E
   // Randomly initialize the variables
1  w, w_e, u_e, v_e, φ ← Init() ;
   // Outer loop (on φ) and inner loop (on w)
2  while not converged do
3    φ = SGD_φ( Σ_e R^e(w • φ) ) ;  // Update φ using stochastic gradient descent (SGD)
4    for k = 1, …, K do
5      for e ∈ E do
6        w_e = SGD_{w_e}( L_ρ(w_e, u_e, w, v_e) ) ;
7        w = (1/|E|) Σ_e (w_e + u_e) ;
8        u_e = u_e + (w_e - w) ;
9        v_e = v_e + ∇_w R^e(w_e • φ) ;
       end
     end
   end

Figure 4: Sequential IRM projection of distributions, where q^+ = arg min_p D_KL(p||q^-) + D_KL(p||q*).

Figure 5: Fashion-MNIST dataset training (a) and testing (b) environments with the b11 color correlation scheme, where the background color depends on the class of the image; in the test environment the dependency is inverted.

Figure 6: MNIST dataset (a) and Fashion MNIST (b) environments as defined in Ahuja et al. (2020)

Nguyen et al. (2018)). ERM is the classical empirical risk minimization method; we always use the cross-entropy loss. IRMv1 enforces the gradient of the model with respect to a fixed scalar classifier to be zero. IRMG models the problem as a game among environments, where each environment learns a separate model. EWC imposes a regularization cost on the parameters that are relevant to the previous task, where relevance is measured by the Fisher Information (FI). GEM uses an episodic memory and computes updates such that the accuracy on previous tasks is not reduced, using gradients stored from previous tasks. MER uses an efficient replay memory and employs a meta-learning gradient update to obtain smooth adaptation among tasks. VCL and variational continual learning with coreset (VCLC) apply variational inference to continual learning. C-VIRMv1 and C-VIRMG refer to, respectively, our proposed variational extensions of IRMv1 and IRMG to sequential environments. C-BVIRM is the implementation with ADMM.

Mean accuracy (over 5 runs) and standard deviation at test time for (n) 2, 6, 10 environments, (d) across datasets, and (c) for the two color correlations (b01,b11).

Mean accuracy (over 10 runs) on train and test environments when training off-line on 2 environments on Colored-MNIST, with the EIIL. (pc 1 = 0.2, pc 2 = 0.1, 50 000 samples)

Mean accuracy (over 5 runs) on train and test environments when training on 1 environment on Colored-MNIST, with and without EIIL. (pc 1 = 0.1)

A.6 THEOREM 3 AND IRM CONNECTION

A.6.1 SEQUENTIAL INFORMATION PROJECTION

In Theorem 3 we show that the information projection (IP) shrinks the support of the output distribution. Suppose we have a sequence of families of distributions P_i; let p_1 ∈ P_1 and, for i > 1, let p_i be the information projection of p_{i-1} onto P_i. Lemma 8 then states that supp(p_t) ⊆ ∩_{i=1}^{t} supp(P_i).

Proof of Lemma 8. We have that, for all i, supp(p_i) ⊆ supp(P_i) ∩ supp(p_{i-1}), where the first condition follows from p_i ∈ P_i in the minimization and the second from Theorem 3. The result follows by iterating the property.
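The support-shrinking property becomes concrete when each family P_i is a support constraint: the I-projection then restricts and renormalizes, so iterated projections can only shrink the support (a small numpy sketch with illustrative supports):

```python
import numpy as np

def i_project(q, support_mask):
    """I-projection of q onto the family of distributions supported inside
    `support_mask`: the minimizer of D_KL(p||q) over that family is q
    restricted to the mask and renormalized."""
    p = np.where(support_mask, q, 0.0)
    return p / p.sum()

q0 = np.ones(5) / 5
p1 = i_project(q0, np.array([1, 1, 1, 1, 0], bool))   # project onto P_1
p2 = i_project(p1, np.array([0, 1, 1, 1, 1], bool))   # then onto P_2
# supp(p2) ⊆ supp(P_2) ∩ supp(p1): each projection intersects the supports.
```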

Each layer uses dropout; dropout is not present in VCL/VCLC since it is not implemented in the original work. The training loss is the cross entropy. We also tested with a separate feature extraction layer, but with no advantage, since the test setup only consists of one task. The IRMv1, IRMG and ERM methods, like the other methods, are trained sequentially as data from each new environment arrives. The continual learning methods are allowed a limited memory of samples from previous environments.

Mean accuracy (over 5 runs) on the Synthetic Dataset (Arjovsky et al. (2019), Creager et al. (2020)). BIRM refers to our bilevel objective Eq. (8) optimized with ADMM.

6. CONCLUSIONS

We aim to broaden the applicability of IRM to settings where environments are observed sequentially. We show that reference approaches fail in this scenario. We introduce a variational Bayesian approach for the estimation of invariant models, together with a solution method based on ADMM. We evaluate the proposed approach against reference models, including those from continual learning, and show a significant improvement in generalization capability.


where L_ρ is the augmented Lagrangian and x_{-i} = {x_j : j ≠ i} is the set of all variables except the i-th.

A.9 CONTINUAL VARIATIONAL INFERENCE

Following Nguyen et al. (2018) we can state the following lemma.

Lemma 12 (Variational Continual Learning). Suppose we have a sequence of datasets D_i, i = 1, …, t, drawn i.i.d. Then the variational estimation of the distribution q_t at step t is given by the KL-divergence projection

  q_t(θ) = arg min_{q(θ)} D_KL( q(θ) || (1/Z_t) q_{t-1}(θ) p(D_t|θ) ),

with Z_t = ∫ q_{t-1}(θ) p(D_t|θ) dθ the normalization factor, which does not depend on q.

Proof of Lemma 12. We are interested in maximizing the a posteriori probability of the parameters given the data, p(θ|D_1^t). By Bayes' rule, p(θ|D_1^t) ∝ p(θ|D_1^{t-1}) p(D_t|θ). We now use a probability distribution that approximates the distribution at step t-1, q_{t-1}(θ) ≈ p(θ|D_1^{t-1}), and we then want to approximate, at time t,

  q_t(θ) ≈ (1/p(D_t)) q_{t-1}(θ) p(D_t|θ).
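For a conjugate model the recursion q_t ∝ q_{t-1} · p(D_t|θ) is exact, and sequential updates reproduce the batch posterior (a minimal numpy sketch for a Gaussian mean with known noise variance; the paper's neural-network setting replaces this closed form with the variational approximation):

```python
import numpy as np

def vcl_gaussian_step(mu_prev, var_prev, data, noise_var=1.0):
    """One step of q_t ∝ q_{t-1} * p(D_t | theta) for a Gaussian mean
    parameter with known observation noise: precisions add, and the
    posterior mean is the precision-weighted combination."""
    n = len(data)
    precision = 1.0 / var_prev + n / noise_var
    var = 1.0 / precision
    mu = var * (mu_prev / var_prev + data.sum() / noise_var)
    return mu, var
```

Processing D_1 then D_2 sequentially gives the same posterior as processing their union in one batch, which is the consistency property the continual variational scheme approximates.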

