SYNERGIES BETWEEN DISENTANGLEMENT AND SPARSITY: A MULTI-TASK LEARNING PERSPECTIVE

Abstract

Although disentangled representations are often said to be beneficial for downstream tasks, current empirical and theoretical understanding is limited. In this work, we provide evidence that disentangled representations coupled with sparse base-predictors improve generalization. In the context of multi-task learning, we prove a new identifiability result that provides conditions under which maximally sparse base-predictors yield disentangled representations. Motivated by this theoretical result, we propose a practical approach to learn disentangled representations based on a sparsity-promoting bi-level optimization problem. Finally, we explore a meta-learning version of this algorithm based on group Lasso multiclass SVM base-predictors, for which we derive a tractable dual formulation. It obtains competitive results on standard few-shot classification benchmarks, while each task uses only a fraction of the learned representations.

1. INTRODUCTION

The recent literature on self-supervised learning has provided evidence that learning a representation on large corpora of data can yield strong performance on a wide variety of downstream tasks (Devlin et al., 2018; Chen et al., 2020), especially in few-shot learning scenarios where the training data for these tasks is limited (Brown et al., 2020b; Dosovitskiy et al., 2021; Radford et al., 2021). Beyond transferring across multiple tasks, these learned representations also lead to improved robustness against distribution shifts (Wortsman et al., 2022) as well as stunning text-conditioned image generation (Ramesh et al., 2022). However, preliminary assessments of the latter have highlighted shortcomings related to compositionality (Marcus et al., 2022), suggesting new algorithmic innovations are needed to make further progress. Another line of work has argued for the integration of ideas from causality to make progress towards more robust and transferable machine learning systems (Pearl, 2019; Schölkopf, 2019; Goyal & Bengio, 2022). Causal representation learning has emerged recently as a field aiming to define and learn representations suited for causal reasoning (Schölkopf et al., 2021). This set of ideas is strongly related to learning disentangled representations (Bengio et al., 2013). Informally, a representation is considered disentangled when its components are in one-to-one correspondence with natural and interpretable factors of variation, such as object positions, colors or shapes. Although a plethora of works have investigated theoretically under which conditions disentanglement is possible (Hyvärinen & Morioka, 2016; 2017; Hyvärinen et al., 2019; Khemakhem et al., 2020a; Locatello et al., 2020a; Klindt et al., 2021; Von Kügelgen et al., 2021; Gresele et al., 2021; Lachapelle et al., 2022; Lippe et al., 2022b; Ahuja et al., 2022c), fewer works have tackled how a disentangled representation could be beneficial for downstream tasks.
Those who did mainly provide empirical rather than theoretical evidence for or against its usefulness (Locatello et al., 2019; van Steenkiste et al., 2019; Miladinović et al., 2019; Dittadi et al., 2021; Montero et al., 2021). In this work, we explore synergies between disentanglement and sparse base-predictors in the context of multi-task learning. At the heart of our contributions is the assumption that only a small subset of all factors of variation are useful for each downstream task, and that this subset might change from one task to another. We will refer to such tasks as sparse tasks, and to their corresponding sets of useful factors as their supports. This assumption was initially suggested by Bengio et al. (2013, Section 3.5): "the feature set being trained may be destined to be used in multiple tasks that may have distinct [and unknown] subsets of relevant features. Considerations such as these lead us to the conclusion that the most robust approach to feature learning is to disentangle as many factors as possible, discarding as little information about the data as is practical". This strategy is very much in line with the current self-supervised learning trend (Radford et al., 2021), except for its focus on disentanglement. Our main contributions are the following: (i) We formalize this "sparse task assumption" and argue theoretically and empirically that, in this context, disentangled representations coupled with sparsity-regularized base-predictors can obtain better generalization than their entangled counterparts (Section 2.1). (ii) We introduce a novel identifiability result (Theorem 1) which shows how one can leverage multiple sparse tasks to learn a shared disentangled representation by regularizing the task-specific predictors to be maximally sparse (Section 2.2.1). Crucially, Assumption 7 formalizes how diverse the task supports have to be in order to guarantee disentanglement.
(iii) Motivated by this result, we propose a tractable bi-level optimization (Problem (4)) to learn the shared representation while regularizing the task-specific base-predictors to be sparse (Section 2.2.2). We validate our theory by showing our approach can indeed disentangle latent factors on tasks constructed from the 3D Shapes dataset (Burgess & Kim, 2018). (iv) Finally, we draw a connection between this bi-level optimization problem and some formulations from the meta-learning literature (Section 2.3). Inspired by our identifiability result, we enhance an existing method (Lee et al., 2019), where the base-learners are now group-sparse SVMs. We show that this new meta-learning algorithm achieves competitive performance on the miniImageNet benchmark (Vinyals et al., 2016), while only using a fraction of the learned representation.

2. SYNERGIES BETWEEN DISENTANGLEMENT AND SPARSITY

In this section, we formally introduce the notion of entangled and disentangled representations. First, we assume the existence of some ground-truth encoder function f_θ : R^d → R^m that maps observations x ∈ X ⊆ R^d, e.g., images, to their corresponding interpretable and usually lower-dimensional representation f_θ(x) ∈ R^m, m ≤ d. The exact form of this ground-truth encoder depends on the task at hand, but also on what the machine learning practitioner considers as interpretable. The learned encoder function is denoted by f_θ̂ : R^d → R^m, and should not be conflated with the ground-truth representation f_θ. For example, f_θ̂ can be parametrized by a neural network. Throughout, we use the following definition of disentanglement.

Definition 1 (Disentangled Representation, Khemakhem et al. 2020a; Lachapelle et al. 2022). A learned encoder function f_θ̂ : R^d → R^m is said to be disentangled w.r.t. the ground-truth representation f_θ when there exists an invertible diagonal matrix D and a permutation matrix P such that, for all x ∈ X, f_θ̂(x) = DP f_θ(x). Otherwise, the encoder f_θ̂ is said to be entangled.

Intuitively, a representation is disentangled when there is a one-to-one correspondence between its components and the components of the ground-truth representation, up to rescaling. Note that there exist less stringent notions of disentanglement which allow for component-wise nonlinear invertible transformations of the factors (Hyvärinen & Morioka, 2017; Hyvärinen et al., 2019).

Notation. Capital bold letters denote matrices and lower-case bold letters denote vectors. The set of integers from 1 to n is denoted by [n]. We write ∥·∥ for the Euclidean norm on vectors and the Frobenius norm on matrices. For a matrix A ∈ R^{k×m}, ∥A∥_{2,1} := Σ_{j=1}^m ∥A_{:j}∥ and ∥A∥_{2,0} := Σ_{j=1}^m 1[∥A_{:j}∥ ≠ 0], where 1[·] is the indicator function. The ground-truth parameter of the encoder function is θ, while that of the learned representation is θ̂. We follow this convention for all the parameters throughout. Table 1 in Appendix A summarizes all the notation.
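As a concrete illustration of Definition 1 and of the group norms above, the following minimal numpy sketch (function names are ours, not the paper's) checks whether a linear map has the form DP, and computes ∥A∥_{2,1} and ∥A∥_{2,0}:

```python
import numpy as np

def is_scaled_permutation(L, tol=1e-8):
    """Check whether L = D P for an invertible diagonal D and a permutation P,
    i.e. exactly one nonzero entry in every row and every column (Definition 1)."""
    nonzero = np.abs(L) > tol
    return bool((nonzero.sum(axis=0) == 1).all() and (nonzero.sum(axis=1) == 1).all())

def norm_2_1(A):
    """||A||_{2,1}: sum of the Euclidean norms of the columns A_{:j}."""
    return np.linalg.norm(A, axis=0).sum()

def norm_2_0(A, tol=1e-8):
    """||A||_{2,0}: number of nonzero columns of A."""
    return int((np.linalg.norm(A, axis=0) > tol).sum())
```

For instance, [[0, 2], [-1, 0]] is a scaled permutation (one nonzero per row and column), while a triangular matrix is not.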

2.1. DISENTANGLEMENT AND SPARSE BASE-PREDICTORS FOR IMPROVED GENERALIZATION

In this section, we compare the generalization performance of entangled and disentangled representations on sparse downstream tasks. We show that the maximum likelihood estimator (defined in Problem (1)) computed on linearly equivalent representations (entangled or disentangled) yields the same model (Proposition 1). However, disentangled representations have better generalization properties when combined with a sparse base-predictor (Proposition 2 and Figure 1).

First, the learned representation f_θ̂ is assumed to be linearly equivalent to the ground-truth representation f_θ, i.e., there exists an invertible matrix L such that, for all x ∈ X, f_θ̂(x) = L f_θ(x). Note that despite being assumed linearly equivalent, the learned representation f_θ̂ might not be disentangled (Definition 1); in that case, we say the representation is linearly entangled. When we refer to a disentangled representation, we write L := DP. Roeder et al. (2021) have shown that many common methods learn representations identifiable up to linear equivalence, such as deep neural networks for classification, contrastive learning (Oord et al., 2018; Radford et al., 2021) and autoregressive language models (Mikolov et al., 2010; Brown et al., 2020a). Consider the following maximum likelihood estimator (MLE):

Ŵ(θ̂)_n := argmax_W Σ_{(x,y)∈D} log p(y; η = W f_θ̂(x)) ,    (1)

where y denotes the label, D := {(x^(i), y^(i))}_{i=1}^n is the dataset, p(y; η) is a distribution over labels parameterized by η ∈ R^k, and Ŵ ∈ R^{k×m} is the task-specific predictor. The following result shows that the maximum likelihood estimator defined in Problem (1) is invariant to invertible linear transformations of the features. Note that it is an almost direct consequence of the invariance of the MLE to reparametrization (Casella & Berger, 2001, Thm. 7.2.10). See Appendix A for a proof.

Proposition 1 (Invariance of the MLE to Invertible Linear Transformations). Let Ŵ(θ̂)_n and Ŵ(θ)_n be the solutions to Problem (1) with the representations f_θ̂ and f_θ, respectively (which we assume are unique). If there exists an invertible matrix L such that, ∀x ∈ X, f_θ̂(x) = L f_θ(x), then we have, ∀x ∈ X, Ŵ(θ̂)_n f_θ̂(x) = Ŵ(θ)_n f_θ(x).

Proposition 1 shows that the model p(y; Ŵ(θ̂)_n f_θ̂(x)) learned by Problem (1) is independent of L, i.e., the model is the same for disentangled and linearly entangled representations. We thus expect both disentangled and linearly entangled representations to perform identically on downstream tasks.

In what follows, we assume the data is generated according to the following process.

Assumption 1 (Data generating process). The input-label pairs are i.i.d. samples from the distribution p(x, y) := p(y | x)p(x) with p(y | x) := p(y; W f_θ(x)), where W ∈ R^{k×m} is the ground-truth coefficient matrix.

To formalize the hypothesis that only a subset of the features f_θ(x) are actually useful to predict the target y, we assume that the ground-truth coefficient matrix W is column-sparse, i.e., ∥W∥_{2,0} = ℓ < m. Under this assumption, it is natural to constrain the MLE as such:

Ŵ(θ̂, ℓ)_n := argmax_W Σ_{(x,y)∈D} log p(y; W f_θ̂(x))  s.t.  ∥W∥_{2,0} ≤ ℓ .    (2)

The following proposition will help us understand how this additional constraint interacts with representations that are disentangled or linearly entangled. See Appendix A for a proof.

Proposition 2 (Population MLE for Linearly Entangled Representations). Let Ŵ(θ̂)_∞ be the solution of the population-based MLE, argmax_W E_{p(x,y)} log p(y; W f_θ̂(x)) (assumed to be unique). Suppose f_θ̂ is linearly equivalent to f_θ and Assumption 1 holds; then Ŵ(θ̂)_∞ = W L^{-1}.

From Proposition 2, one can see that if the representation f_θ̂ is disentangled, then ∥Ŵ(θ̂)_∞∥_{2,0} = ∥W(DP)^{-1}∥_{2,0} = ∥W∥_{2,0} = ℓ. Thus, in that case, the sparsity constraint in Problem (2) does not exclude the population MLE estimator from its hypothesis class, and yields a decrease in the generalization gap (Bickel et al., 2009; Lounici et al., 2011a; Mohri et al., 2018) without biasing the estimator.
Contrarily, when f_θ̂ is linearly entangled, the population MLE might have more nonzero columns than the ground-truth, and would thus be excluded from the hypothesis space of Problem (2), which, in turn, would bias the estimator.

Empirical validation. We now present a simple simulated experiment to validate the above claim that disentangled representations coupled with sparsity regularization can have better generalization. Figure 1 compares L1-regularized linear regression (Lasso regression, Tibshirani 1996) and Ridge regression (Hoerl & Kennard, 1970) on both disentangled and linearly entangled representations. Lasso regression coupled with the disentangled representation obtains better generalization than the other alternatives when ℓ/m = 5% and the number of samples is very small. We can also see that disentanglement, sparsity regularization and sufficient sparsity in the ground-truth data generating process are all necessary to see a significant improvement, in line with our discussion. Lastly, the performance of all methods converges to the same value when the number of samples grows. See Appendix D.1 for more details and discussion of the results.
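A comparison of this kind can be sketched as follows, assuming Gaussian latents and labels generated from a column-sparse ground-truth predictor w. This is a minimal scikit-learn illustration written by us, not the paper's experimental code, and the hyperparameters are arbitrary:

```python
import numpy as np
from sklearn.linear_model import Lasso, Ridge

rng = np.random.default_rng(0)
m, n_train, n_test = 20, 15, 1000
ell = 2                                   # number of useful factors, ell << m

# Ground-truth factors z = f_theta(x) and a sparse ground-truth predictor w.
w = np.zeros(m)
w[:ell] = rng.normal(size=ell)
z_tr = rng.normal(size=(n_train, m))
z_te = rng.normal(size=(n_test, m))
y_tr = z_tr @ w + 0.1 * rng.normal(size=n_train)
y_te = z_te @ w

# Disentangled representation: L = DP; linearly entangled: random invertible L.
P = np.eye(m)[rng.permutation(m)]
L_dis = np.diag(rng.uniform(0.5, 2.0, size=m)) @ P
L_ent = rng.normal(size=(m, m))

test_mse = {}
for rep_name, L in [("disentangled", L_dis), ("entangled", L_ent)]:
    for model in (Lasso(alpha=0.05), Ridge(alpha=1.0)):
        model.fit(z_tr @ L.T, y_tr)       # learned features f(x) = L z
        err = np.mean((model.predict(z_te @ L.T) - y_te) ** 2)
        test_mse[(rep_name, type(model).__name__)] = err
```

In runs of this sketch, Lasso on the disentangled features tends to achieve the lowest test error in the small-sample regime, mirroring the behavior reported in Figure 1.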

2.2. DISENTANGLEMENT VIA SPARSE MULTITASK LEARNING

In Section 2.1, we argued that disentangled representations can improve generalization when combined with sparse base-predictors, but we did not provide an approach to learn them. We first provide a new identification result (Theorem 1, Section 2.2.1), which states that, in the context of sparse multitask learning, sparse base-predictors yield disentangled representations. Then, in Section 2.2.2, we provide a practical way to learn disentangled representations motivated by our identifiability result. Throughout this section, we assume the learner is given a set of T datasets {D_1, . . . , D_T}, where each dataset D_t := {(x^(t,i), y^(t,i))}_{i=1}^n consists of n pairs of input x ∈ R^d and label y ∈ Y. The set of labels Y might contain either class indices or real values, depending on whether we are concerned with classification or regression tasks.

2.2.1. IDENTIFIABILITY ANALYSIS

We now present the main theoretical result of our work, which shows how learning a shared representation across tasks while penalizing the task-specific base-predictors to be sparse can induce disentanglement. Our theory relies on the following ground-truth data generating process.

Assumption 2 (Ground-truth data generating process). For each task t, the dataset D_t is made of i.i.d. samples from the distribution p(x, y | W^(t)) := p(y | x, W^(t)) p(x | W^(t)) with p(y | x, W^(t)) := p(y; W^(t) f_θ(x)), where W^(t) ∈ R^{k×m} is the task-specific ground-truth coefficient matrix. Moreover, the matrices W^(t) are i.i.d. samples from some probability measure P_W with support W. Also, for all W ∈ W, the support of p(x | W) is X ⊆ R^d (fixed across tasks).

The above assumption states that (i) the ground-truth coefficient matrices W^(t) are task-specific while the representation f_θ is shared across all the tasks, (ii) the task-specific W^(t) are sampled i.i.d. from some distribution P_W, and (iii) the support of x is shared across tasks.

Assumption 3 (Identifiability of η). The parameter η is identifiable from p(y; η), i.e., ∀y, p(y; η) = p(y; η̃) =⇒ η = η̃. This property holds, e.g., when p(y; η) is a Gaussian in the usual (µ, σ²) parameterization. More generally, it also holds for minimal parameterizations of exponential families (Wainwright & Jordan, 2008).

The following assumption requires the ground-truth representation f_θ(x) to vary enough such that its image cannot be trapped inside a proper subspace.

[Figure 2 caption: The red distribution satisfies Assumption 6, but the blue and orange distributions do not. The red lines are level sets of a Gaussian distribution with a full-rank covariance. The blue line represents the support of a Gaussian distribution with a low-rank covariance. The orange dots represent a distribution with finite support. The green vector a shows that the condition is violated for both the blue and the orange distributions, since, in both cases, W_{1,S} a = 0 (orthogonality) with probability greater than zero.]

Assumption 4 (Sufficient representation variability). There exist x^(1), . . . , x^(m) ∈ X such that the matrix F := [f_θ(x^(1)), . . . , f_θ(x^(m))] is invertible.

The following assumption requires that the support of the distribution P_W is sufficiently rich.

Assumption 5 (Sufficient task variability). There exist W^(1), . . . , W^(m) ∈ W and row indices i_1, . . . , i_m ∈ [k] such that the rows W^(1)_{i_1,:}, . . . , W^(m)_{i_m,:} are linearly independent.

Under Assumptions 2 to 5, the representation f_θ is identifiable up to linear equivalence (see Theorem 2 in Appendix B). Similar results were shown by Roeder et al. (2021); Ahuja et al. (2022c). The next assumptions will guarantee disentanglement.

In order to formalize the intuitive idea that most tasks do not require all features, we denote by S^(t) the support of the matrix W^(t), i.e., S^(t) := {j ∈ [m] | W^(t)_{:j} ≠ 0}. In other words, S^(t) is the set of features which are useful to predict y in the t-th task; note that it is unknown to the learner. For our analysis, we decompose P_W as P_W = Σ_{S∈P([m])} p(S) P_{W|S}, where P([m]) is the collection of all subsets of [m], p(S) is the probability that the support of W is S, and P_{W|S} is the conditional distribution of W given that its support is S. Let S be the support of the distribution p(S), i.e., S := {S ∈ P([m]) | p(S) > 0}. The set S will have an important role in Assumption 7 and Theorem 1. The following assumption requires that P_{W|S} does not concentrate on certain proper subspaces.

Assumption 6 (Intra-support sufficient task variability). For all S ∈ S and all a ∈ R^|S| \ {0}, P_{W|S}{W ∈ R^{k×m} | W_{:S} a = 0} = 0.

We illustrate the above assumption in Figure 2 in the simpler case where k = 1.
For instance, Assumption 6 holds when the distribution of W_{1,S} | S has a density w.r.t. the Lebesgue measure on R^|S|, which is true for example when W_{1,S} | S ∼ N(0, Σ) and the covariance matrix Σ is full-rank (red distribution in Figure 2). However, if Σ is not full-rank, the probability distribution of W_{1,S} | S concentrates its mass on a proper linear subspace V ⊊ R^|S|, which violates Assumption 6 (blue distribution in Figure 2). Another important counter-example is when P_{W|S} concentrates some of its mass on a point W^(0), i.e., P_{W|S}{W^(0)} > 0 (orange distribution in Figure 2). Interestingly, there are distributions over W_{1,S} | S that do not have a density w.r.t. the Lebesgue measure, but still satisfy Assumption 6. This is the case, e.g., when W_{1,S} | S puts uniform mass over a (|S|−1)-dimensional sphere embedded in R^|S| and centered at zero. See Appendix B.2 for a justification. The following assumption requires that the support S of p(S) is "rich enough".

Assumption 7 (Sufficient support variability). For all j ∈ [m], ⋃_{S ∈ S : j ∉ S} S = [m] \ {j}.

Intuitively, Assumption 7 requires that, for every feature j, one can find a set of tasks such that their supports cover all features except j itself. Figure 3 shows an example of S satisfying Assumption 7. Removing this assumption would only yield partial disentanglement (Lachapelle & Lacoste-Julien, 2022).

We are now ready to state the main theoretical result of this work, which provides a bi-level optimization problem for which the optimal representations are guaranteed to be disentangled. It assumes infinitely many tasks are observed, with task-specific ground-truth matrices W sampled from P_W. We denote by Ŵ(W) the task-specific estimator of W. See Appendix B.1 for a proof. Note that we suggest a tractable relaxation in Section 2.2.2.

Theorem 1 (Sparse multi-task learning for disentanglement). Let θ̂ be a minimizer of

min_θ̂  E_{P_W} E_{p(x,y|W)} [−log p(y; Ŵ(W) f_θ̂(x))]
s.t.  ∀ W ∈ W,  Ŵ(W) ∈ argmin_{W̃ : ∥W̃∥_{2,0} ≤ ∥W∥_{2,0}}  E_{p(x,y|W)} [−log p(y; W̃ f_θ̂(x))] .    (3)

Then, under Assumptions 2 to 7, f_θ̂ is disentangled w.r.t. f_θ (Definition 1).

Intuitively, this optimization problem effectively selects a representation f_θ̂ that (i) allows a perfect fit of the data distribution, and (ii) allows the task-specific estimators Ŵ(W) to be as sparse as the ground-truth W. With the same disentanglement guarantees, Theorem 4 in Appendix B presents a variation of Problem (3) which enforces the weaker constraint E_{P_W} ∥Ŵ(W)∥_{2,0} ≤ E_{P_W} ∥W∥_{2,0}, instead of ∥Ŵ(W)∥_{2,0} ≤ ∥W∥_{2,0} for each task W individually.
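For a finite collection of task supports, Assumption 7 is straightforward to verify programmatically. The sketch below (function name is ours) checks that, for every feature j, the union of the supports excluding j covers all other features:

```python
def satisfies_support_variability(supports, m):
    """Assumption 7: for every feature j in range(m), the union of the task
    supports S that exclude j must cover every other feature (the set
    range(m) with j removed)."""
    supports = [set(S) for S in supports]
    for j in range(m):
        union = set().union(*[S for S in supports if j not in S])
        if union != set(range(m)) - {j}:
            return False
    return True
```

For example, with m = 3, the supports {0,1}, {1,2}, {0,2} satisfy the assumption, while the single full support {0,1,2} does not, since no task excludes any feature.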

2.2.2. TRACTABLE BILEVEL OPTIMIZATION PROBLEMS FOR SPARSE MULTITASK LEARNING

Problem (3) was shown to yield a disentangled representation (Theorem 1), but is intractable due to the L_{2,0}-seminorm. Thus, we use the L_{2,1} convex relaxation of the L_{2,0}-seminorm, which is also known to promote group sparsity (Obozinski et al., 2006; Argyriou et al., 2008; Lounici et al., 2009):

min_θ̂  −(1/(Tn)) Σ_{t=1}^T Σ_{(x,y)∈D_t} log p(y; Ŵ^(t) f_θ̂(x))
s.t.  ∀ t ∈ [T],  Ŵ^(t) ∈ argmin_W  −(1/n) Σ_{(x,y)∈D_t} log p(y; W f_θ̂(x)) + λ_t ∥W∥_{2,1} .    (4)

Following Bengio (2000) and Pedregosa (2016), one can compute the (hyper)gradient of the outer function using implicit differentiation, even if the inner optimization problem is non-smooth (Bertrand et al., 2020; Bolte et al., 2021; Malézieux et al., 2022; Bolte et al., 2022). Once the hypergradient is computed, one can optimize Problem (4) using usual first-order methods (Wright & Nocedal, 1999).

Note that the quantity Ŵ^(t) f_θ̂(x) is invariant to simultaneous rescaling of Ŵ^(t) by a scalar and of f_θ̂(x) by its inverse. Thus, without constraints on f_θ̂(x), ∥Ŵ^(t)∥_{2,1} can be made arbitrarily small. This is a usual problem in sparse dictionary learning (Kreutz-Delgado et al., 2003; Mairal et al., 2008; 2009; 2011), where unit-norm constraints are usually imposed on the columns of the dictionary. Here, since f_θ̂ is parametrized by a neural network, we suggest applying batch or layer normalization (Ioffe & Szegedy, 2015; Ba et al., 2016) to control its norm. Since the number of relevant features might be task-dependent, Problem (4) has one regularization hyperparameter λ_t per task. To limit the number of hyperparameters in practice, we select λ_t := λ for all t ∈ [T].
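For concreteness, when p(y; η) is Gaussian (squared loss), the inner problem of Problem (4) can be solved by proximal gradient descent, where the proximal operator of the L_{2,1} penalty is a column-wise block soft-threshold. The sketch below is our own minimal implementation of this inner solve under that Gaussian assumption, not the paper's code:

```python
import numpy as np

def block_soft_threshold(A, tau):
    """Column-wise proximal operator of tau * ||.||_{2,1}."""
    norms = np.maximum(np.linalg.norm(A, axis=0, keepdims=True), 1e-12)
    return A * np.maximum(1.0 - tau / norms, 0.0)

def inner_group_lasso(F, Y, lam, n_iters=500):
    """Minimize (1/2n) ||Y - F W^T||_F^2 + lam * ||W||_{2,1} over W (k x m)
    by proximal gradient descent (ISTA). F is (n, m): one feature row
    f(x_i) per sample; Y is (n, k): one target row per sample."""
    n, m = F.shape
    k = Y.shape[1]
    W = np.zeros((k, m))
    step = n / np.linalg.norm(F, 2) ** 2   # 1 / Lipschitz constant of the gradient
    for _ in range(n_iters):
        grad = (F @ W.T - Y).T @ F / n     # gradient of the smooth squared loss
        W = block_soft_threshold(W - step * grad, step * lam)
    return W
```

With λ large enough, entire columns of Ŵ are zeroed out, i.e., whole features are dropped; one such solve per task, with its own λ_t, corresponds to the inner level of Problem (4).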

2.3. LINK WITH META-LEARNING

In the setting known as meta-learning (Finn et al., 2017), for a large number of tasks T, we are given training datasets D_t^train, which usually contain a small number of samples n. Unlike in the multi-task setting (i.e., unlike in Section 2.2), we are also given separate test datasets D_t^test to evaluate how well the learned model generalizes to new test samples. In meta-learning, the goal is to learn a training procedure which will generalize well on out-of-distribution tasks.

The bi-level formulation in Problem (4) is closely related to metric-based meta-learning (Snell et al., 2017; Bertinetto et al., 2019), where a shared representation f_θ̂ is learned across all tasks. The representation is jointly learned with simple task-specific classifiers, which are usually optimization-based classifiers, such as support-vector machines. Formally, metric-based meta-learning can be formulated as follows:

min_θ̂  Σ_{t=1}^T Σ_{(x,y)∈D_t^test} L_out(Ŵ_θ̂^(t); f_θ̂(x), y)
s.t.  ∀ t ∈ [T],  Ŵ_θ̂^(t) ∈ argmin_W Σ_{(x,y)∈D_t^train} L_in(W; f_θ̂(x), y) .

Inspired by Lee et al. (2019), where the base-classifiers were multiclass support-vector machines (SVMs, Crammer & Singer 2001), we propose to use group-Lasso-penalized multiclass SVMs in order to introduce sparsity in the base-learners. With Y ∈ R^{n×k} the one-hot encoding of y ∈ R^n,

L_in(W; f_θ̂(x), y) := (1/n) Σ_{i=1}^n max_{l∈[k]} [(W_{l:} − W_{y_i:}) · f_θ̂(x_i) + 1 − Y_{il}] + (λ_1/n) ∥W∥_{2,1} + (λ_2/(2n)) ∥W∥² .    (5)

In few-shot learning settings, the number of features m is usually much larger than the number of samples n (in Lee et al. 2019, m = 1.6·10^4 and n ≤ 25). In such scenarios, SVM-like problems are usually solved through their dual problems (Boyd et al., 2004, Chap. 5), for computational (Hsieh et al., 2008) and theoretical (Shalev-Shwartz & Zhang, 2012) benefits.

Proposition 3 (Dual Group Lasso Soft-Margin Multiclass SVM). The dual of the inner problem with L_in as defined in (5) writes

min_{Λ∈R^{n×k}}  (1/(2λ_2)) Σ_{j=1}^m ∥BST((Y − Λ)^⊤ F_{:j}, λ_1)∥² + ⟨Y, Λ⟩
s.t.  ∀ i ∈ [n],  Σ_{l=1}^k Λ_{il} = 1  and  ∀ l ∈ [k],  Λ_{il} ≥ 0 ,    (6)

with BST the block soft-thresholding operator, BST(a, τ) := (1 − τ/∥a∥)_+ a, and F ∈ R^{n×m} the matrix stacking the representations {f_θ̂(x)}_{(x,y)∈D_t^train} as rows. In addition, the primal-dual link writes, for all j ∈ [m], W_{:j} = BST((Y − Λ)^⊤ F_{:j}, λ_1) / λ_2.

The proof of Proposition 3 can be found in Appendix C.1. The objective of Problem (6) is composed of a smooth term and a block-separable non-smooth term, hence it can be solved efficiently using proximal block coordinate descent (Tseng, 2001). As stated in Section 2.2, argmin differentiation of the solution of Problem (6) can be done using implicit differentiation (Bertrand et al., 2022). Although Theorem 1 is not directly applicable to the meta-learning formulation proposed in this section, we conjecture that similar techniques could be reused to prove an identifiability result in this setting.
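The block soft-thresholding operator and the primal-dual link of Proposition 3 translate directly into code. This sketch (function names are ours) recovers the primal weights W from the dual variables Λ:

```python
import numpy as np

def bst(a, tau):
    """Block soft-thresholding: BST(a, tau) = (1 - tau / ||a||)_+ * a."""
    norm = np.linalg.norm(a)
    if norm <= tau:
        return np.zeros_like(a)
    return (1.0 - tau / norm) * a

def primal_from_dual(Lam, Y, F, lam1, lam2):
    """Primal-dual link of Proposition 3:
    W_{:j} = BST((Y - Lam)^T F_{:j}, lam1) / lam2 for each feature j.
    Y and Lam have shape (n, k); F has shape (n, m), one representation per row."""
    R = (Y - Lam).T @ F                                   # shape (k, m)
    cols = [bst(R[:, j], lam1) / lam2 for j in range(R.shape[1])]
    return np.stack(cols, axis=1)                         # shape (k, m)
```

Features j with ∥(Y − Λ)^⊤ F_{:j}∥ ≤ λ_1 yield entirely zero columns W_{:j}, which is exactly how the group-Lasso penalty drops unused features in the base-learner.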

3. RELATED WORK

Disentanglement. Since the work of Bengio et al. (2013), many methods have been proposed to learn disentangled representations based on various heuristics (Higgins et al., 2017; Chen et al., 2018; Kim & Mnih, 2018; Kumar et al., 2018; Bouchacourt et al., 2018). Following the work of Locatello et al. (2019), which highlighted the lack of identifiability in modern deep generative models, many works have proposed more or less weak forms of supervision motivated by identifiability analyses (Locatello et al., 2020a; Klindt et al., 2021; Von Kügelgen et al., 2021; Ahuja et al., 2022a;c; Zheng et al., 2022). A similar line of work has adopted the causal representation learning perspective (Lachapelle et al., 2022; Lachapelle & Lacoste-Julien, 2022; Lippe et al., 2022b;a; Ahuja et al., 2022b; Yao et al., 2022; Brehmer et al., 2022). The problem of identifiability was well known in the independent component analysis (ICA) community (Hyvärinen et al., 2001; Hyvärinen & Pajunen, 1999), which came up with solutions for general nonlinear mixing functions by leveraging auxiliary information (Hyvärinen & Morioka, 2016; 2017; Hyvärinen et al., 2019; Khemakhem et al., 2020a;b). Another approach is to consider restricted hypothesis classes of mixing functions (Taleb & Jutten, 1999; Gresele et al., 2021). Contrarily to most of the above works, we do not assume that the inputs x are generated by transforming a latent random variable z through a bijective decoder g. Instead, we assume the existence of a not necessarily bijective ground-truth feature extractor f_θ from which the labels can be predicted using only a subset of its components in every task (Assumption 2). Many of these works make assumptions about the distribution of latent factors, e.g., (conditional) independence, exponential family or other parametric assumptions.
In contrast, we make comparatively weaker assumptions on the support of the ground-truth features (Assumption 4), which are allowed to present dependencies (Section 4). Locatello et al. (2020b) proposed a semi-supervised learning approach to disentangle in cases where a few samples are labelled with the factors of variations themselves. This is different from our approach, as the labels that we consider can be sampled from some p(y; W f_θ(x)), which is more general. Ahuja et al. (2022c) consider a setting similar to ours, but they rely on the independence and non-gaussianity of the latent factors for disentanglement using linear ICA.

Multi-task, transfer & invariant learning. The statistical advantages of multi-task representation learning are well understood (Lounici et al., 2011a;b; Maurer et al., 2016). However, the usefulness of disentangled representations for downstream tasks has mostly been assessed empirically, with both positive (Zhang et al., 2019; Miladinović et al., 2019; Dittadi et al., 2021) and negative results (Locatello et al., 2019; Montero et al., 2021). Invariant risk minimization (Arjovsky et al., 2020; Ahuja et al., 2020; Krueger et al., 2021; Lu et al., 2021) aims at learning a representation that elicits a task-invariant base-predictor. This differs from our approach, which learns base-predictors that are task-specific.

Dictionary learning and sparse coding. We contrast our approach, which jointly learns a dense representation and sparse base-predictors (Problem (4)), with the line of work which consists in learning sparse representations (Chen et al., 1998; Gribonval & Lesage, 2006). For instance, sparse dictionary learning (Mairal et al., 2009; 2011; Maurer et al., 2013) is an unsupervised technique which aims at learning sparse representations that refer to atoms of a learned dictionary. Contrarily to our method, which computes the representation of a single input x by evaluating a function approximator f_θ̂, in sparse dictionary learning the representation of a single input is computed by minimizing a reconstruction loss.
In the case of supervised dictionary learning (Mairal et al., 2008), an additional (potentially expressive) classifier is learned. This large literature has led to a wide variety of estimators: for instance, Mairal et al. (2008, Eq. 4), which minimizes the sum of the classification error and the approximation error of the code, or Mairal et al. (2011); Malézieux et al. (2022), which introduce bi-level formulations that share similarities with ours.

4. EXPERIMENTS

Semi-real experiments on 3D Shapes. We now illustrate Theorem 1 by applying Problem (4) to tasks generated using the 3D Shapes dataset (Burgess & Kim, 2018).

Data generation. For all tasks t, the labelled dataset D_t = {(x^(t,i), y^(t,i))}_{i=1}^n is generated by first sampling the ground-truth latent variables z^(t,i) := f_θ(x^(t,i)) i.i.d. according to some distribution p(z), while the corresponding input is obtained via x^(t,i) := f_θ^{-1}(z^(t,i)) (f_θ is invertible in 3D Shapes). Then, a sparse weight vector w^(t) is sampled randomly to compute the labels of each example as y^(t,i) := w^(t) · z^(t,i) + ϵ^(t,i), where ϵ^(t,i) is independent Gaussian noise. Figure 4 explores various choices of p(z), obtained by varying the level of correlation between the latent variables and the level of noise on the ground-truth latents. See Appendix D.2 for more details about the data generating process.

Algorithms. In this setting where p(y; η) is a Gaussian with fixed variance, the inner problem of Problem (4) amounts to Lasso regression; we thus refer to this approach as inner-Lasso. We also evaluate a simple variation of Problem (4) in which the L1 norm is replaced by an L2 norm, and refer to it as inner-Ridge. In addition, we evaluate the representation obtained by performing linear ICA (Comon, 1992) on the representation learned by inner-Ridge; the case λ = 0 corresponds to the approach of Ahuja et al. (2022c).

Discussion. Figure 4 reports the disentanglement performance of the three methods, as measured by the mean correlation coefficient, or MCC (Hyvärinen & Morioka, 2016; Khemakhem et al., 2020a) (Appendix D.2). In all settings, inner-Lasso obtains high MCC for some values of λ, being on par with or surpassing the baselines.
As the theory suggests, it is robust to high levels of correlation between the latents, as opposed to inner-Ridge with ICA, which is strongly affected by such correlations (since ICA assumes independence). We can also see how additional noise on the latent variables hurts inner-Ridge with ICA while leaving inner-Lasso unaffected. Figure 6 in the appendix reports similar conclusions with the DCI metrics (Eastwood & Williams, 2018).

Few-shot learning experiments. Despite the lack of ground-truth latent factors in standard few-shot learning benchmarks, we also evaluate our meta-learning objective introduced in Section 2.3, using the dual formulation of the group-Lasso-penalized SVM as our base-learner, on the miniImageNet dataset (Vinyals et al., 2016). The objective of this experiment is to show that the sparse formulation of the meta-learning objective is capable of reaching similar levels of performance, while using a fraction of the features. Details about the experimental settings are provided in Appendix D.3.

Discussion. In Figure 5 (left), we report how frequently the learned features are used by the base-learner on meta-training tasks; the gradual decrease in usage suggests that the features are reused in different contexts, across different tasks. We also observe (Figure 5, right) that adding sparsity to the base-learner may also improve performance on meta-training tasks, while only using a fraction of all the features available in the learned representation, supporting our observations in Section 2.1 on the effect of sparsity on generalization on natural images (see Appendix D.3 for further discussion about how this still tests generalization). We also observe that some level of sparsity improves the performance on novel meta-test tasks, albeit to a smaller extent.
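The MCC metric used above can be computed by matching learned and ground-truth factors through a linear assignment on absolute Pearson correlations. A minimal sketch (ours, not the paper's evaluation code):

```python
import numpy as np
from scipy.optimize import linear_sum_assignment

def mean_correlation_coefficient(z_hat, z):
    """MCC between learned factors z_hat and ground-truth factors z,
    both of shape (n_samples, m): mean absolute Pearson correlation
    under the best one-to-one matching of components."""
    m = z.shape[1]
    corr = np.corrcoef(z_hat.T, z.T)[:m, m:]           # cross-correlation block
    rows, cols = linear_sum_assignment(-np.abs(corr))  # maximize total |corr|
    return float(np.abs(corr[rows, cols]).mean())
```

A representation that is a permutation and (possibly negative) rescaling of the ground-truth factors reaches MCC = 1, consistent with Definition 1.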

5. CONCLUSION

In this work, we investigated the synergies between sparsity, disentanglement and generalization. We showed that when the downstream task can be solved using only a fraction of the factors of variation, disentangled representations combined with sparse base-predictors can improve generalization (Section 2.1). Our novel identifiability result (Theorem 1) sheds light on how, in a multi-task setting, sparsity regularization on the task-specific predictors can induce disentanglement. This led to a practical bi-level optimization problem that was shown to yield disentangled representations on regression tasks based on the 3D Shapes dataset. Finally, we explored a meta-learning formulation extending this approach, and showed how sparse base-learners can help with generalization, while only using a small fraction of the features.

A PROOFS OF SECTION 2.1

Proposition 1 (MLE Invariance to Invertible Linear Transformations of the Features). Let Ŵ^{(θ)} and Ŵ^{(θ̂)} be the solutions to Problem (1) with the representations f_θ and f_{θ̂}, respectively (which we assume are unique). If there exists an invertible matrix L such that, ∀x ∈ X, f_{θ̂}(x) = L f_θ(x), then we have, ∀x ∈ X, Ŵ^{(θ̂)} f_{θ̂}(x) = Ŵ^{(θ)} f_θ(x).

Proof. By definition of Ŵ^{(θ̂)}, we have that, for all W̃ ∈ R^{k×m},

∑_{(x,y)∈D} log p(y; Ŵ^{(θ̂)} f_{θ̂}(x)) ≥ ∑_{(x,y)∈D} log p(y; W̃ f_{θ̂}(x))   (7)
∑_{(x,y)∈D} log p(y; Ŵ^{(θ̂)} L f_θ(x)) ≥ ∑_{(x,y)∈D} log p(y; W̃ L f_θ(x)) .

Because R^{k×m} L = R^{k×m}, we have that, for all W̃ ∈ R^{k×m},

∑_{(x,y)∈D} log p(y; Ŵ^{(θ̂)} L f_θ(x)) ≥ ∑_{(x,y)∈D} log p(y; W̃ f_θ(x)) ,

which is to say that Ŵ^{(θ)} = Ŵ^{(θ̂)} L, or, put differently, Ŵ^{(θ̂)} = Ŵ^{(θ)} L^{-1}. This implies

Ŵ^{(θ̂)} f_{θ̂}(x) = Ŵ^{(θ)} L^{-1} L f_θ(x) = Ŵ^{(θ)} f_θ(x) ,

which is what we wanted to show.

Proposition 2 (Population MLE for Linearly Entangled Representations). Let Ŵ_∞^{(θ̂)} be the solution of the population-based MLE, arg max_W E_{p(x,y)} log p(y; W f_{θ̂}(x)) (assumed to be unique). Suppose f_{θ̂} is linearly equivalent to f_θ and Assumption 1 holds. Then, Ŵ_∞^{(θ̂)} = W L^{-1}.

Proof.
By definition of Ŵ_∞^{(θ̂)}, we have that, for all W̃ ∈ R^{k×m},

E_{p(x,y)} log p(y; Ŵ_∞^{(θ̂)} f_{θ̂}(x)) ≥ E_{p(x,y)} log p(y; W̃ f_{θ̂}(x))   (11)
E_{p(x,y)} log p(y; Ŵ_∞^{(θ̂)} L f_θ(x)) ≥ E_{p(x,y)} log p(y; W̃ L f_θ(x)) .

In particular, the inequality holds for W̃ := W L^{-1}, which yields

E_{p(x,y)} log p(y; Ŵ_∞^{(θ̂)} L f_θ(x)) ≥ E_{p(x,y)} log p(y; W f_θ(x))   (13)
0 ≥ E_{p(x,y)} [log p(y; W f_θ(x)) − log p(y; Ŵ_∞^{(θ̂)} L f_θ(x))]   (14)
0 ≥ E_{p(x)} KL(p(y; W f_θ(x)) || p(y; Ŵ_∞^{(θ̂)} L f_θ(x))) .

Since the KL divergence is always non-negative, we have that

E_{p(x)} KL(p(y; W f_θ(x)) || p(y; Ŵ_∞^{(θ̂)} L f_θ(x))) = 0 ,

which in turn implies

E_{p(x,y)} log p(y; Ŵ_∞^{(θ̂)} L f_θ(x)) = E_{p(x,y)} log p(y; W f_θ(x))   (17)
E_{p(x,y)} log p(y; Ŵ_∞^{(θ̂)} L f_θ(x)) = E_{p(x,y)} log p(y; W L^{-1} L f_θ(x))
E_{p(x,y)} log p(y; Ŵ_∞^{(θ̂)} f_{θ̂}(x)) = E_{p(x,y)} log p(y; W L^{-1} f_{θ̂}(x)) .   (20)

Since the solution to the population MLE from Proposition 2 is assumed to be unique, this equality holds if and only if Ŵ_∞^{(θ̂)} = W L^{-1}.

B IDENTIFIABILITY THEORY

The following lemma will be important for proving Theorem 3. The argument is taken from Lachapelle et al. (2022).

Lemma 1 (Sparsity pattern of an invertible matrix contains a permutation). Let L ∈ R^{m×m} be an invertible matrix. Then, there exists a permutation σ such that L_{i,σ(i)} ≠ 0 for all i.

Proof. Since the matrix L is invertible, its determinant is non-zero, i.e.,

det(L) := ∑_{σ∈S_m} sign(σ) ∏_{i=1}^m L_{i,σ(i)} ≠ 0 ,

where S_m is the set of m-permutations. This equation implies that at least one term of the sum is non-zero, meaning there exists σ ∈ S_m such that, for all i ∈ [m], L_{i,σ(i)} ≠ 0.

For all W ∈ W, we denote by Ŵ(W) some estimator of W. The following result provides conditions under which, if Ŵ(W) allows a perfect fit of the ground-truth distribution p(y | x, W), then the representation f_{θ̂} and the parameter W are identified up to an invertible linear transformation. Many works have shown similar results in various contexts (Hyvärinen & Morioka, 2016; Khemakhem et al., 2020a; Roeder et al., 2021; Ahuja et al., 2022c); we reuse some of their proof techniques.

Theorem 2 (Linear identifiability). Let Ŵ(·) : W → R^{k×m}. Suppose Assumptions 2 to 5 hold and that, for all W ∈ W, x ∈ X and y ∈ Y, the following holds:

p(y; Ŵ(W) f_{θ̂}(x)) = p(y; W f_θ(x)) .   (23)

Then, there exists an invertible matrix L ∈ R^{m×m} such that, for all x ∈ X, f_θ(x) = L f_{θ̂}(x), and such that, for all W ∈ W, Ŵ(W) = W L.

Proof. By Assumption 3, Equation (23) implies that W f_θ(x) = Ŵ(W) f_{θ̂}(x). Assumption 5 ensures that we can construct an invertible matrix U whose rows are taken from the matrices W^{(1)}, …, W^{(d_z)}:

U := [W^{(1)}_{i_1,:} ; … ; W^{(d_z)}_{i_{d_z},:}] .

Construct analogously Û := [Ŵ(W^{(1)})_{i_1,:} ; … ; Ŵ(W^{(d_z)})_{i_{d_z},:}]. This allows us to write U f_θ(x) = Û f_{θ̂}(x). Left-multiplying by U^{-1} on both sides yields f_θ(x) = L f_{θ̂}(x), where L := U^{-1} Û. Using the invertible matrix F from Assumption 4, we can thus write F = L F̂, where we defined F̂ := [f_{θ̂}(x^{(1)}), …, f_{θ̂}(x^{(d_z)})]. Since F is invertible, so are L and F̂.
By substituting F = L F̂ into W F = Ŵ(W) F̂, we obtain W L F̂ = Ŵ(W) F̂. Right-multiplying both sides by F̂^{-1} yields Ŵ(W) = W L.

The following theorem is where most of the theoretical contribution of this work lies. Note that Theorem 1, from the main text, is a straightforward application of this result.

Theorem 3 (Disentanglement via task sparsity). Let Ŵ(·) : W → R^{k×m}. Suppose Assumptions 3 to 7 hold and that, for all W ∈ W, x ∈ X and y ∈ Y, the following holds:

p(y; Ŵ(W) f_{θ̂}(x)) = p(y; W f_θ(x)) .

Moreover, assume that E∥Ŵ(W)∥_{2,0} ≤ E∥W∥_{2,0}, where both expectations are taken w.r.t. P_W and ∥W∥_{2,0} := ∑_{j=1}^m 1(W_{:j} ≠ 0), with 1(·) the indicator function. Then, f_{θ̂} is disentangled w.r.t. f_θ (Definition 1).

Proof. First of all, by Assumptions 3 to 5, we can apply Theorem 2 to conclude that f_θ(x) = L f_{θ̂}(x) and Ŵ(W) = W L for some invertible matrix L. We can thus write E∥W L∥_{2,0} ≤ E∥W∥_{2,0}. We can write

E∥W∥_{2,0} = E_{p(S)} E[∑_{j=1}^m 1(W_{:j} ≠ 0) | S]   (24)
= E_{p(S)} ∑_{j=1}^m E[1(W_{:j} ≠ 0) | S]   (25)
= E_{p(S)} ∑_{j=1}^m P_{W|S}[W_{:j} ≠ 0]   (26)
= E_{p(S)} ∑_{j=1}^m 1(j ∈ S) ,   (27)

where the last step follows from the definition of S. We now perform similar steps for E∥W L∥_{2,0}:

E∥W L∥_{2,0} = E_{p(S)} E[∑_{j=1}^m 1((W L)_{:j} ≠ 0) | S]   (28)
= E_{p(S)} ∑_{j=1}^m E[1((W L)_{:j} ≠ 0) | S]   (29)
= E_{p(S)} ∑_{j=1}^m P_{W|S}[(W L)_{:j} ≠ 0]   (30)
= E_{p(S)} ∑_{j=1}^m P_{W|S}[W_{:S} L_{S,j} ≠ 0] .   (31)

Notice that P_{W|S}[W_{:S} L_{S,j} ≠ 0] = 1 − P_{W|S}[W_{:S} L_{S,j} = 0]. Let N_j be the support of L_{:j}, i.e. N_j := {i ∈ [m] | L_{i,j} ≠ 0}. When S ∩ N_j = ∅, L_{S,j} = 0 and thus P_{W|S}[W_{:S} L_{S,j} = 0] = 1. When S ∩ N_j ≠ ∅, L_{S,j} ≠ 0, and by Assumption 6 we have that P_{W|S}[W_{:S} L_{S,j} = 0] = 0. Thus,

P_{W|S}[W_{:S} L_{S,j} ≠ 0] = 1 − 1(S ∩ N_j = ∅)   (33)
= 1(S ∩ N_j ≠ ∅) ,   (34)

which allows us to write

E∥W L∥_{2,0} = E_{p(S)} ∑_{j=1}^m 1(S ∩ N_j ≠ ∅) .

We thus have that E∥W L∥_{2,0} ≤ E∥W∥_{2,0} becomes

E_{p(S)} ∑_{j=1}^m 1(S ∩ N_j ≠ ∅) ≤ E_{p(S)} ∑_{j=1}^m 1(j ∈ S) .   (37)

Since L is invertible, by Lemma 1 there exists a permutation σ : [m] → [m] such that, for all j ∈ [m], L_{j,σ(j)} ≠ 0; in other words, for all j ∈ [m], j ∈ N_{σ(j)}. Of course, we can permute the terms of the l.h.s. of eq. (37), which yields

E_{p(S)} ∑_{j=1}^m 1(S ∩ N_{σ(j)} ≠ ∅) ≤ E_{p(S)} ∑_{j=1}^m 1(j ∈ S)   (38)
E_{p(S)} ∑_{j=1}^m [1(S ∩ N_{σ(j)} ≠ ∅) − 1(j ∈ S)] ≤ 0 .   (39)
We notice that each term 1(S ∩ N_{σ(j)} ≠ ∅) − 1(j ∈ S) ≥ 0, since whenever j ∈ S we also have j ∈ S ∩ N_{σ(j)} (recall j ∈ N_{σ(j)}). Thus, the l.h.s. of eq. (39) is a sum of non-negative terms which is itself non-positive. This means that every term in the sum is zero:

∀S ∈ S, ∀j ∈ [m], 1(S ∩ N_{σ(j)} ≠ ∅) = 1(j ∈ S) .   (40)

Importantly,

∀j ∈ [m], ∀S ∈ S, j ∉ S ⟹ S ∩ N_{σ(j)} = ∅ ,

and since S ∩ N_{σ(j)} = ∅ ⟺ N_{σ(j)} ⊆ S^c, we have that

∀j ∈ [m], ∀S ∈ S, j ∉ S ⟹ N_{σ(j)} ⊆ S^c   (42)
∀j ∈ [m], N_{σ(j)} ⊆ ∩_{S∈S : j∉S} S^c .   (43)

By Assumption 7, we have that ∪_{S∈S : j∉S} S = [m] \ {j}. By taking the complement on both sides and using De Morgan's law, we get ∩_{S∈S : j∉S} S^c = {j}, which implies that N_{σ(j)} = {j} by Equation (43) (recall j ∈ N_{σ(j)}). Thus, L = DP, where D is an invertible diagonal matrix and P is a permutation matrix.
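Lemma 1, on which the argument above relies, can be checked numerically: a permutation hitting only non-zero entries of an invertible L is a maximum-weight matching on its sparsity pattern. A minimal sketch (the helper name is ours, not from the paper):

```python
import numpy as np
from scipy.optimize import linear_sum_assignment

def permutation_in_support(L):
    """Find sigma with L[i, sigma[i]] != 0 for all i, if one exists."""
    # Maximum-weight bipartite matching on the 0/1 sparsity pattern of L.
    rows, cols = linear_sum_assignment((L != 0).astype(float), maximize=True)
    sigma = cols[np.argsort(rows)]
    if np.all(L[np.arange(len(L)), sigma] != 0):
        return sigma
    return None  # no such permutation exists (then L must be singular)

rng = np.random.default_rng(0)
# A sparse but invertible matrix: its support must contain a permutation.
L = np.triu(rng.normal(size=(5, 5)))  # upper triangular, non-zero diagonal
assert abs(np.linalg.det(L)) > 0
sigma = permutation_in_support(L)
print(sigma)
```

For this triangular example the identity permutation already works; the matching just certifies that some valid σ exists, as Lemma 1 guarantees.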

B.1 PROOF OF THEOREM 1

Before presenting Theorem 1 from the main text, we first present a variation of it where we constrain E∥Ŵ(W)∥_{2,0} to be smaller than E∥W∥_{2,0}. We note that this is weaker than imposing ∥Ŵ(W)∥_{2,0} ≤ ∥W∥_{2,0} for all W ∈ W, as is the case in Problem (3) of Theorem 1.

Theorem 4 (Sparse multi-task learning for disentanglement). Let θ̂ be a minimizer of

min_θ E_{P_W} E_{p(x,y|W)} [−log p(y; Ŵ(W) f_θ(x))]
s.t. ∀W ∈ W, Ŵ(W) ∈ arg min_{W̃} E_{p(x,y|W)} [−log p(y; W̃ f_θ(x))]
E_{P_W} ∥Ŵ(W)∥_{2,0} ≤ E_{P_W} ∥W∥_{2,0} .   (44)

Then, under Assumptions 2 to 7, f_{θ̂} is disentangled w.r.t. f_θ (Definition 1).

Proof. First, notice that

0 ≤ E_{P_W} E_{p(x|W)} KL(p(y; W f_θ(x)) || p(y; Ŵ(W) f_{θ̂}(x)))   (45)
⟺ E_{P_W} E_{p(x,y|W)} [−log p(y; W f_θ(x))] ≤ E_{P_W} E_{p(x,y|W)} [−log p(y; Ŵ(W) f_{θ̂}(x))] .

For fixed values of x and W, it is well known that KL(p(y; W f_θ(x)) || p(y; Ŵ(W) f_{θ̂}(x))) = 0 if and only if, for all y ∈ Y, p(y; W f_θ(x)) = p(y; Ŵ(W) f_{θ̂}(x)). By Assumption 3, this is equivalent to W f_θ(x) = Ŵ(W) f_{θ̂}(x). Thus, for equality to hold in eq. (45), we need W f_θ(x) = Ŵ(W) f_{θ̂}(x) everywhere. Of course, the global minimum can be achieved while respecting E_{P_W} ∥Ŵ(W)∥_{2,0} ≤ E_{P_W} ∥W∥_{2,0}, simply by setting θ̂ := θ and Ŵ(W) := W. The above implies that, if θ̂ is some minimizer of Problem (44), we must have that W f_θ(x) = Ŵ(W) f_{θ̂}(x) everywhere and E_{P_W} ∥Ŵ(W)∥_{2,0} ≤ E_{P_W} ∥W∥_{2,0}. Thus, Theorem 3 implies the desired conclusion.

Based on Theorem 4, we can slightly adjust the argument to prove Theorem 1 from the main text.

Theorem 1 (Sparse multi-task learning for disentanglement). Let θ̂ be a minimizer of

min_θ E_{P_W} E_{p(x,y|W)} [−log p(y; Ŵ(W) f_θ(x))]
s.t. ∀W ∈ W, Ŵ(W) ∈ arg min_{W̃ : ∥W̃∥_{2,0} ≤ ∥W∥_{2,0}} E_{p(x,y|W)} [−log p(y; W̃ f_θ(x))] .   (3)

Then, under Assumptions 2 to 7, f_{θ̂} is disentangled w.r.t. f_θ (Definition 1).

Proof. The first part of the argument in the proof of Theorem 4 applies here as well, meaning: for equality to hold in eq.
(45), we need W f_θ(x) = Ŵ(W) f_{θ̂}(x) everywhere. This global minimum can be achieved while respecting ∥Ŵ(W)∥_{2,0} ≤ ∥W∥_{2,0} for all W ∈ W, simply by setting θ̂ := θ and Ŵ(W) := W. This means that, if θ̂ is some minimizer of Problem (3), we must have that W f_θ(x) = Ŵ(W) f_{θ̂}(x) holds everywhere and that, for all W ∈ W, ∥Ŵ(W)∥_{2,0} ≤ ∥W∥_{2,0}. Of course, this means E_{P_W} ∥Ŵ(W)∥_{2,0} ≤ E_{P_W} ∥W∥_{2,0}, which allows us to apply Theorem 3 to obtain the desired conclusion.

C DUAL OF THE GROUP LASSO SOFT-MARGIN MULTICLASS SVM

We denote the Fenchel conjugate of a function h : R^d → R, for y ∈ R^d, by h*(y) = sup_{x∈R^d} ⟨x, y⟩ − h(x).

Definition 2 (Primal Group Lasso Soft-Margin Multiclass SVM). The primal problem of the group Lasso soft-margin multiclass SVM is defined as

min_{W∈R^{k×m}} L_in(W; F, Y) := ∑_{i=1}^n max_{l∈[k]} (1 − Y_{il} + (W_{l:} − W_{y_i:}) F_{i:}) + λ₁ ∥W∥_{2,1} + (λ₂/2) ∥W∥² .   (47)

Proposition 3 (Dual Group Lasso Soft-Margin Multiclass SVM). The dual of the inner problem with L_in as defined in (5) writes

min_{Λ∈R^{n×k}} (1/(2λ₂)) ∑_{j=1}^m ∥BST((Y − Λ)ᵀ F_{:j}, λ₁)∥² + ⟨Y, Λ⟩
s.t. ∀i ∈ [n], ∑_{l=1}^k Λ_{il} = 1 , ∀i ∈ [n], l ∈ [k], Λ_{il} ≥ 0 ,

with BST the block soft-thresholding operator, BST : (a, τ) ↦ (1 − τ/∥a∥)₊ a, and F ∈ R^{n×m} the concatenation of {f_θ(x)}_{(x,y)∈D_train}. In addition, the primal-dual link writes, for all j ∈ [m], W_{:j} = BST((Y − Λ)ᵀ F_{:j}, λ₁)/λ₂.

The primal objective (47) can be hard to minimize with modern solvers. Moreover, in few-shot learning applications, the number of features m is usually much larger than the number of samples n (in Lee et al. (2019), m = 1.6 · 10⁴ and n ≤ 25); hence we solve the dual of Problem (47).

Proof of Proposition 3. Let g : u ↦ λ₁∥u∥ + (λ₂/2)∥u∥². The proof of Proposition 3 is composed of the following lemmas.

Lemma 2. i) The dual of Problem (47) is

min_{Λ∈R^{n×k}} ∑_{j=1}^m g*((Y − Λ)ᵀ F_{:j}) + ⟨Y, Λ⟩ s.t. ∀i ∈ [n], ∑_{l=1}^k Λ_{il} = 1 , ∀i ∈ [n], l ∈ [k], Λ_{il} ≥ 0 ,

where g* is the Fenchel conjugate of the function g. ii) The Fenchel conjugate of the function g writes, ∀v ∈ R^k,

g*(v) = (1/(2λ₂)) ∥BST(v, λ₁)∥² .   (49)

Lemma 2 i) and ii) together yield Proposition 3.
Proof of Lemma 2 i). The Lagrangian of Problem (47) writes

L(W, ξ, Λ) = ∑_{j=1}^m g(W_{:j}) + ∑_i ξ_i + ∑_{i=1}^n ∑_{l=1}^k (1 − ξ_i − W_{y_i:} · F_{i:} + W_{l:} · F_{i:} − Y_{il}) Λ_{il} .   (50)

Setting ∂_ξ L(W, ξ, Λ) = 0 yields ∀i ∈ [n], ∑_{l=1}^k Λ_{il} = 1. The Lagrangian then rewrites

min_{W,ξ} L(W, ξ, Λ) = min_{W,ξ} ∑_{j=1}^m g(W_{:j}) + ∑_{i=1}^n ξ_i + ∑_{i=1}^n ∑_{l=1}^k (−ξ_i − W_{y_i:} · F_{i:} + W_{l:} · F_{i:} − Y_{il}) Λ_{il} + n
= ∑_{j=1}^m min_{W_{:j}} [ g(W_{:j}) − ⟨(Y − Λ)ᵀ F_{:j}, W_{:j}⟩ ] − ∑_{i=1}^n ∑_{l=1}^k Y_{il} Λ_{il} + n
= − ∑_{j=1}^m g*((Y − Λ)ᵀ F_{:j}) − ⟨Y, Λ⟩ + n .

The dual problem then writes (dropping the constant n):

min_{Λ∈R^{n×k}} ∑_{j=1}^m g*((Y − Λ)ᵀ F_{:j}) + ⟨Y, Λ⟩ s.t. ∀i ∈ [n], ∑_{l=1}^k Λ_{il} = 1 , ∀i ∈ [n], l ∈ [k], Λ_{il} ≥ 0 .

Proof of Lemma 2 ii). Let κ := λ₂/λ₁ and h : u ↦ ∥u∥₂ + (κ/2)∥u∥₂², so that g = λ₁ h and g*(v) = λ₁ h*(v/λ₁). The proof is done using the following steps.

Lemma 3. i) h*(v) = (1/(2κ))∥v∥₂² − ((κ/2)∥·∥₂² □ ∥·∥₂)(v/κ). ii) ((κ/2)∥·∥₂² □ ∥·∥₂)(v) = (κ/2)∥v∥₂² − (1/(2κ))∥BST(κv, 1)∥₂².

Proof of Lemma 3 i).

h*(v) = sup_w vᵀw − ∥w∥₂ − (κ/2)∥w∥₂²
= (1/(2κ))∥v∥₂² + sup_w [ −(κ/2)∥w − v/κ∥₂² − ∥w∥₂ ]
= (1/(2κ))∥v∥₂² − inf_w [ (κ/2)∥w − v/κ∥₂² + ∥w∥₂ ]
= (1/(2κ))∥v∥₂² − ((κ/2)∥·∥₂² □ ∥·∥₂)(v/κ) .

Proof of Lemma 3 ii).

((κ/2)∥·∥₂² □ ∥·∥₂)(v) = ((κ/2)∥·∥₂² □ ∥·∥₂)**(v)
= ((1/(2κ))∥·∥₂² + ι_{B₂})*(v)
= sup_{∥w∥₂≤1} vᵀw − (1/(2κ))∥w∥₂²
= (κ/2)∥v∥₂² + sup_{∥w∥₂≤1} [ −(1/(2κ))∥κv − w∥₂² ]
= (κ/2)∥v∥₂² − (1/(2κ))∥BST(κv, 1)∥₂² .

Combining Lemma 3 i) and ii) gives h*(v) = (1/(2κ))∥BST(v, 1)∥₂², and thus

g*(u) = λ₁ h*(u/λ₁) = (λ₁/(2κ)) ∥BST(u/λ₁, 1)∥₂² = (λ₁²/(2λ₂)) ∥BST(u/λ₁, 1)∥₂² = (1/(2λ₂)) ∥BST(u, λ₁)∥₂² .

D EXPERIMENTAL DETAILS

D.1 DISENTANGLED REPRESENTATION COUPLED WITH SPARSITY REGULARIZATION IMPROVES GENERALIZATION

We consider the following data generating process: we sample the ground-truth features f_θ(x) from a Gaussian distribution N(0, Σ), where Σ ∈ R^{m×m} and Σ_{i,j} = 0.9^{|i−j|}. Moreover, the labels are given by y = w · f_θ(x) + ϵ, where w ∈ R^m, ϵ ∼ N(0, 0.04) and m = 100.
The ground-truth weight vector w is sampled once from N(0, I_{m×m}), and we mask some of its components to zero: we vary the fraction of meaningful features (ℓ/m) from very sparse (ℓ/m = 5%) to less sparse (ℓ/m = 80%) settings. For each case, we study the sample complexity by varying the number of training samples from 25 to 150, while evaluating the generalization performance on a larger test dataset (1000 samples). To generate the entangled representations, we multiply the true latent variables f_θ(x) by a randomly sampled orthogonal matrix L, i.e., f_{θ̂}(x) := L f_θ(x). For the disentangled representation, we simply consider the true latents, i.e., f_{θ̂}(x) := f_θ(x). Note that, in principle, we could have considered an invertible matrix L that is not orthogonal for the linearly entangled representation, and a component-wise rescaling for the disentangled representation. The advantage of opting for our approach instead is that the condition number of the covariance matrix of f_{θ̂}(x) is the same for both the entangled and the disentangled representations, hence offering a fairer comparison. For both the entangled and the disentangled representations, we solve the regression problem with Lasso and Ridge regression, where the associated hyperparameters (regularization strength) are selected using 5-fold cross-validation on the training dataset. Using both Lasso and Ridge regression helps us show the effect of encouraging sparsity. In Figure 1, for the sparsest case (ℓ/m = 5%), we observe that the Disentangled-Lasso approach has the best performance when we have fewer training samples, while the Entangled-Lasso approach performs the worst. As we increase the number of training samples, the performance of Entangled-Lasso approaches that of Disentangled-Lasso; however, learning with the Disentangled-Lasso approach is more sample efficient. Disentangled-Lasso obtains R² greater than 0.5 with only 25 training samples, while the other approaches obtain R² close to zero.
Also, Disentangled-Lasso converges to the optimal R² using only 50 training samples, while Entangled-Lasso requires 150 samples. As expected, the improvement due to disentanglement does not occur with ridge regression: there is no difference between Disentangled-Ridge and Entangled-Ridge, because the L2 norm is invariant to orthogonal transformations. Also, having sparsity in the underlying task is important: Disentangled-Lasso shows the largest improvement for ℓ/m = 5%, with the gains shrinking as we decrease the sparsity of the underlying task (ℓ/m = 80%).

D.2 SEMI-REAL EXPERIMENTS ON 3D SHAPES

D.2.1 DATA GENERATION

The 3D Shapes dataset (Burgess & Kim, 2018) is generated from six factors of variation: floor hue, wall hue, object hue, scale, shape and orientation (the latter taking values in [−30, 30]). These are the factors we aim to disentangle. We standardize them to have mean 0 and variance 1. We denote by Z ⊂ R⁶ the set of all possible latent factor combinations. In our framework, this corresponds to the support of the ground-truth features f_θ(x). We note that the points in Z are arranged in a grid-like fashion in R⁶.

Task generation. For all tasks t, the labelled dataset D_t = {(x^{(t,i)}, y^{(t,i)})}_{i=1}^n is generated by first sampling the ground-truth latent variables z^{(t,i)} := f_θ(x^{(t,i)}) i.i.d. according to some distribution p(z) over Z, while the corresponding input is obtained via x^{(t,i)} := f_θ^{-1}(z^{(t,i)}) (f_θ is invertible in 3D Shapes). Then, a sparse weight vector w^{(t)} is sampled randomly as w^{(t)} := w̄^{(t)} ⊙ s^{(t)}, where ⊙ is the Hadamard (component-wise) product, w̄^{(t)} ∼ N(0, I) and s^{(t)} ∈ {0, 1}⁶ is a binary vector with independent components sampled from a Bernoulli distribution with p = 0.5. Then, the labels are computed for each example as y^{(t,i)} := w^{(t)} · z^{(t,i)} + ϵ^{(t,i)}, where ϵ^{(t,i)} is independent Gaussian noise. In every task, the dataset has size n = 50. New tasks are generated continuously as we train. Figures 4 and 6 explore various choices of p(z), varying both the level of correlation between the latent variables and the level of noise on the ground-truth latents.
Noise on latents. To make the dataset slightly more realistic, we get rid of the artificial grid-like structure of the latents by adding noise. This procedure transforms Z into a new support Z_α, where α is the noise level. Formally, Z_α := ∪_{z∈Z} {z + u_z}, where the u_z are i.i.d. samples from the uniform distribution over the hypercube [−α∆z₁/2, α∆z₁/2] × [−α∆z₂/2, α∆z₂/2] × … × [−α∆z₆/2, α∆z₆/2], where ∆z_i denotes the gap between contiguous values of the factor z_i. When α = 0, no noise is added and the support Z is unchanged, i.e., Z₀ = Z. As long as α ∈ [0, 1], contiguous points in Z cannot be interchanged in Z_α. We also clarify that the ground-truth mapping f_θ is consequently modified to f_{θ,α}: for all x ∈ X, f_{θ,α}(x) := f_θ(x) + u_z. We emphasize that the u_z are sampled only once, so that f_{θ,α}(x) is a deterministic mapping. Varying correlations. To verify that our approach is robust to correlations in the latents, we construct p(z) as follows: we consider a Gaussian density centered at 0 with covariance Σ_{i,j} := ρ + 1(i = j)(1 − ρ). Then, we evaluate this density on the points of Z_α and renormalize to obtain a well-defined probability distribution over Z_α. We denote by p_{α,ρ}(z) the distribution obtained by this construction. In the top rows of Figures 4 and 6, the latents are sampled from p_{α=1,ρ}(z) and ρ varies between 0 and 0.99. In the bottom rows of Figures 4 and 6, the latents are sampled from p_{α,ρ=0.9}(z) and α varies from 0 to 1.
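The construction of p_{α,ρ}(z) above can be sketched as follows for two factors instead of six (the grid values and sample sizes here are illustrative assumptions): evaluate a correlated Gaussian density on the grid points, renormalize, and sample from the resulting discrete distribution.

```python
import numpy as np
from scipy.stats import multivariate_normal

rng = np.random.default_rng(0)
d, rho = 2, 0.9  # two factors for illustration; the paper uses six

# Grid of standardized latent values (stand-in for the 3D Shapes grid Z).
vals = np.linspace(-1.5, 1.5, 10)
grid = np.stack(np.meshgrid(vals, vals), axis=-1).reshape(-1, d)

# Gaussian density with Sigma_ij = rho + 1(i==j)(1 - rho), renormalized on the grid.
Sigma = np.full((d, d), rho) + (1 - rho) * np.eye(d)
p = multivariate_normal(np.zeros(d), Sigma).pdf(grid)
p /= p.sum()

# Sample correlated latents from the discrete distribution p_rho(z).
z = grid[rng.choice(len(grid), size=5000, p=p)]
print(np.corrcoef(z.T)[0, 1])  # strongly positive, close to rho
```

The empirical correlation of the samples is slightly below ρ because the grid truncates and discretizes the Gaussian, which is precisely the construction's point: correlation is controlled while the support remains the grid.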

D.2.2 METRICS

We evaluate disentanglement via the mean correlation coefficient (Hyvärinen & Morioka, 2016; Khemakhem et al., 2020a), which is computed as follows: the Pearson correlation matrix C between the ground-truth features and the learned ones is computed, and then MCC := max_{π∈permutations} (1/m) ∑_{j=1}^m |C_{j,π(j)}|. We also evaluate linear equivalence by performing linear regression to predict the ground-truth factors from the learned ones, and report the mean of the Pearson correlations between the ground-truth latents and the learned ones. This metric is known as the coefficient of multiple correlation, R, and turns out to be the square root of the more widely known coefficient of determination, R². The advantage of using R over R² is that we always have MCC ≤ R.
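The MCC computation above can be sketched directly: compute the cross-correlation matrix and solve the permutation maximization as a linear assignment problem (a sketch of the standard computation, not the paper's exact implementation).

```python
import numpy as np
from scipy.optimize import linear_sum_assignment

def mcc(z_true, z_learned):
    """Mean correlation coefficient under the best feature permutation."""
    m = z_true.shape[1]
    # Pearson correlations between every (true, learned) feature pair.
    C = np.corrcoef(z_true.T, z_learned.T)[:m, m:]
    row, col = linear_sum_assignment(-np.abs(C))  # best permutation pi
    return np.abs(C[row, col]).mean()

rng = np.random.default_rng(0)
z = rng.normal(size=(1000, 4))
# A permuted and rescaled copy is perfectly "disentangled": MCC = 1.
z_hat = 2.5 * z[:, [2, 0, 3, 1]]
print(mcc(z, z_hat))  # -> 1.0
```

Because MCC only looks at one-to-one correlations, it is invariant to permutation and per-feature rescaling, matching Definition 1's notion of disentanglement.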

D.2.3 ARCHITECTURE, INNER SOLVER & HYPERPARAMETERS

We use the four-layer convolutional neural network typically used in the disentanglement literature (Locatello et al., 2019). As mentioned in Section 2.2.2, the norm of the representation f_θ(x) must be controlled to make sure the regularization remains effective. To do so, we apply batch normalization (Ioffe & Szegedy, 2015) at the very last layer of the neural network and do not learn its scale and shift parameters. Empirically, we do see the expected behavior that, without any normalization, the norm of f_θ(x) explodes as we train, leading to instabilities and low sparsity. In these experiments, the distribution p(y; η) used for learning is a Gaussian with fixed variance. In that case, the inner problem of Section 2.2.2 reduces to Lasso regression. Computing the hypergradient w.r.t. θ requires solving this inner problem; to do so, we use proximal coordinate descent (Tseng, 2001; Richtárik & Takáč, 2014). In Figures 4 and 6, we explore various levels of regularization λ. In our implementation of inner-Lasso, λ_max := (1/n) ∥Fᵀy∥_∞, where F ∈ R^{n×m} is the design matrix of the features of the samples of a task, while in the inner-Ridge implementation, λ_max := (1/n) ∥F∥².
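The λ_max formula for inner-Lasso can be verified with a hedged sketch: under the 1/(2n) squared-loss scaling (which is also the convention used by scikit-learn's Lasso, an assumption we make here to have a runnable check), λ = λ_max is exactly the smallest regularization for which the Lasso solution is all-zero.

```python
import numpy as np
from sklearn.linear_model import Lasso

rng = np.random.default_rng(0)
n, m = 50, 30
F = rng.normal(size=(n, m))  # design matrix of task features
y = rng.normal(size=n)

# lambda_max = ||F^T y||_inf / n: smallest lambda zeroing out all coefficients
# (valid for the objective (1/(2n))||y - Fw||^2 + lambda*||w||_1).
lam_max = np.max(np.abs(F.T @ y)) / n

coef_at_max = Lasso(alpha=lam_max, fit_intercept=False).fit(F, y).coef_
coef_below = Lasso(alpha=0.5 * lam_max, fit_intercept=False).fit(F, y).coef_
print(np.count_nonzero(coef_at_max), np.count_nonzero(coef_below))
```

This is why sweeping λ on a log grid below λ_max, as in Figures 4 and 6, covers the full range from dense to fully sparse task predictors.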

D.2.4 EXPERIMENTS VIOLATING ASSUMPTIONS

In this section, we explore variations of the experiments of Section 4 in which the assumptions of Theorem 1 are violated. Figure 7 shows different degrees of violation of Assumption 7. We consider the cases where S := {{1, 2}, {3, 4}, {5, 6}} (block size = 2), S := {{1, 2, 3}, {4, 5, 6}} (block size = 3) and S := {{1, 2, 3, 4, 5, 6}} (block size = 6). Note that the latter case corresponds to having no sparsity at all in the ground-truth model, i.e., all tasks require all features. The reader can verify that these three cases indeed violate Assumption 7. In all cases, the distribution p(S) puts uniform mass over its support S. Similarly to the experiments from the main text, w := w̄ ⊙ s, where w̄ ∼ N(0, I) and s ∼ p(S) (s is the binary representation of the set S). Overall, we can see that inner-Lasso does not perform as well when Assumption 7 is violated. For example, when there is no sparsity at all (block size = 6), inner-Lasso performs poorly and is even surpassed by inner-Ridge. Nevertheless, for mild violations (block size = 2), disentanglement (as measured by MCC) remains reasonably high. We further notice that all methods obtain a very good R score in all settings. This is expected in light of Theorem 2, which guarantees identifiability up to linear transformation without requiring Assumption 7. Figure 8 presents experiments that are identical to those of Figure 4 in the main text, except for how w is generated. Here, the components of w are sampled independently according to w_i ∼ Laplace(µ = 0, b = 1). We note that, under this process, the probability that w_i = 0 is zero. This means all features are useful and Assumption 7 is violated. That being said, due to the fat-tailed behavior of the Laplace distribution, many components of w will be close to zero (relative to its variance). Thus, this can be thought of as a weaker form of sparsity where many features are relatively unimportant.
Figure 8 shows that inner-Lasso can still disentangle very well. In fact, the performance is very similar to the experiments that presented actual sparsity (Figure 4 ).

D.2.5 VISUAL EVALUATION

Figures 9 to 12 show how various learned representations respond to changing a single factor of variation in the image (Higgins et al., 2017, Figure 7.A.B). We see what was expected: the higher the MCC, the more disentangled the learned features appear, thus validating MCC as a good metric for disentanglement. See captions for details. For computing DCI-disentanglement, we normalize each row of the importance matrix I[i, :] by its sum so that it represents a probability distribution. Disentanglement is then given by (1/m) ∑_{i=1}^m (1 − H(I[i, :])), where H denotes the entropy of a distribution. Note that, in the desired case where each ground-truth latent component is explained by a single inferred latent component, we would have H(I[i, :]) = 0, since the probability distribution is a one-hot vector. Similarly, if each ground-truth latent component were explained uniformly by all the inferred latents, H(I[i, :]) would be maximized and hence the DCI score minimized. To compute DCI-completeness, we first normalize each column of the importance matrix I[:, j] by its sum so that it represents a probability distribution, and then compute (1/m) ∑_{j=1}^m (1 − H(I[:, j])). Figure 13 shows the results for the 3D Shapes experiments (Section 4) with the DCI metrics: we find the same trend as with the MCC metric (Figure 4), namely that inner-Lasso is more robust to correlation between the latent variables, while the performance of inner-Ridge + ICA drops significantly with increasing correlation.
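The DCI computation described above can be sketched as follows. This is a simplified stand-in, not the exact implementation: we assume a least-squares importance matrix and base-m entropies so that both scores lie in [0, 1].

```python
import numpy as np

def dci_scores(z_true, z_learned, eps=1e-12):
    """DCI-style disentanglement/completeness from a linear importance matrix.

    I[i, j] = |coefficient| of learned latent j when predicting true latent i.
    """
    # Least-squares coefficients predicting true latents from learned ones.
    W, *_ = np.linalg.lstsq(z_learned, z_true, rcond=None)
    I = np.abs(W.T)  # shape: (num true latents, num learned latents)

    def entropy(p):
        p = p / (p.sum() + eps)          # normalize to a distribution
        return -(p * np.log(p + eps)).sum() / np.log(p.size)  # base-m entropy

    disent = np.mean([1 - entropy(I[i, :]) for i in range(I.shape[0])])
    complete = np.mean([1 - entropy(I[:, j]) for j in range(I.shape[1])])
    return disent, complete

rng = np.random.default_rng(0)
z = rng.normal(size=(2000, 5))
# A permuted copy of the latents yields a one-hot importance matrix: scores ~ 1.
d_perm, c_perm = dci_scores(z, z[:, [1, 0, 3, 4, 2]])
print(d_perm, c_perm)
```

For a perfectly permuted representation both rows and columns of I are one-hot, so both scores approach 1, consistent with Definition 1 requiring both to be high.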

D.3 META-LEARNING EXPERIMENTS

Experimental settings. We evaluate the performance of our meta-learning algorithm, based on a group-sparse SVM base-learner, on the miniImageNet dataset (Vinyals et al., 2016). Following the standard nomenclature in few-shot classification (Hospedales et al., 2021) of k-shot N-way, where N is the number of classes in each classification task and k is the number of samples per class in the training dataset D_t^train, we consider two settings: 1-shot 5-way and 5-shot 5-way. Note that the results presented in Figure 5 only show the performance on 5-shot classification. We use the same residual network architecture as in Lee et al. (2019), with 12 layers and a representation of size m = 1.6 × 10⁴. Even though we consider a similar base-learner to MetaOptNet (Lee et al., 2019) (namely, an SVM), our control experiment with λ = 0 cannot be directly compared to the performance of the model reported in that prior work. The reason is that, in order to control for any other sources of "effective regularization" (e.g., data augmentation, label smoothing), we do not include the modifications made in MetaOptNet to improve performance. Moreover, we used a different solver (proximal block-coordinate descent, as opposed to a QP solver) to solve the inner Problem (6). Generalization on meta-training tasks. In Section 2.3, we argued that evaluating the performance of the learned representations on meta-training tasks (i.e., tasks similar to the ones seen during meta-training) still measures generalization to new tasks. Indeed, the new tasks on which we evaluate performance were created using the same classes as the tasks used during meta-training, but using combinations of classes that may not have been seen in any task used for optimizing Problem (5). However, evaluation in meta-learning is typically done on meta-test tasks, i.e., tasks based on concepts that were never seen by any task during meta-training.
This evaluation requires a stronger notion of generalization, closer to out-of-distribution generalization.
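The proximal block-coordinate descent solver mentioned above relies on the block soft-thresholding operator BST from Proposition 3. A minimal sketch of BST and a numerical check of the conjugate formula g*(v) = ∥BST(v, λ₁)∥²/(2λ₂) using a generic optimizer (the full dual solver is not reproduced here):

```python
import numpy as np
from scipy.optimize import minimize

def bst(a, tau):
    """Block soft-thresholding: BST(a, tau) = (1 - tau/||a||)_+ a."""
    norm = np.linalg.norm(a)
    return np.zeros_like(a) if norm <= tau else (1.0 - tau / norm) * a

# Fenchel conjugate of g(u) = l1*||u|| + (l2/2)*||u||^2, per Lemma 2 ii):
# g*(v) = ||BST(v, l1)||^2 / (2*l2).
l1, l2 = 1.0, 2.0
v = np.array([2.0, -1.0, 0.5])
g_star = np.linalg.norm(bst(v, l1)) ** 2 / (2 * l2)

# Numerical check: g*(v) = sup_u <u, v> - g(u), solved with a generic optimizer.
res = minimize(lambda u: -(u @ v - l1 * np.linalg.norm(u) - 0.5 * l2 * u @ u), x0=v)
g_star_opt = -res.fun
print(g_star, g_star_opt)  # the two values agree

# Primal-dual link: the maximizer is u* = BST(v, l1) / l2, mirroring
# W_:j = BST((Y - Lambda)^T F_:j, lambda_1) / lambda_2 in Proposition 3.
```

Since each dual coordinate update reduces to one BST evaluation per feature block, the solver scales well when m ≫ n, which is exactly the few-shot regime described above.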

Table 2: Performance of our meta-learning algorithm on the miniImageNet benchmark. The performance is reported as the mean accuracy and 95% confidence interval on 1000 meta-test tasks. We also report the performance of MetaOptNet (Lee et al., 2019) as a reference, even though it is not directly comparable to our SVM baseline (see text for details).

Base-learner                      | 5-way 1-shot    | 5-way 5-shot
SVM (λ = 0)                       | 53.29 ± 0.60%   | 69.26 ± 0.51%
Group-sparse SVM (λ = 0.01)       | 54.22 ± 0.61%   | 70.01 ± 0.51%
MetaOptNet (Lee et al., 2019)     | 64.09 ± 0.62%   | 80.00 ± 0.45%

Nonetheless, we observe in Table 2 that the performance of the meta-learning method improves when the base-learners are group-sparse.



We assume the solution is unique. p(y; η) could be a Gaussian density (regression) or a categorical distribution (classification).



Figure 1: Test performance for the entangled and disentangled representation using Lasso and Ridge regression. All the results are averaged over 10 seeds, with standard error shown in error bars.
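The Ridge half of the comparison in Figure 1 can be reproduced with a minimal sketch of the Appendix D.1 generating process (the regularization strength and noise scale here are illustrative; the paper cross-validates them): ridge fits on the entangled and disentangled representations give identical predictions, since the L2 norm is invariant to orthogonal transformations.

```python
import numpy as np
from scipy.stats import ortho_group
from sklearn.linear_model import Ridge

rng = np.random.default_rng(0)
m, n, ell = 100, 50, 5

# Ground-truth disentangled features: N(0, Sigma) with Sigma_ij = 0.9^|i-j|.
idx = np.arange(m)
Sigma = 0.9 ** np.abs(idx[:, None] - idx[None, :])
Z = rng.multivariate_normal(np.zeros(m), Sigma, size=n)

# Sparse ground-truth weights (ell/m = 5% meaningful features) and labels.
w = rng.normal(size=m)
w[rng.permutation(m)[ell:]] = 0.0
y = Z @ w + rng.normal(scale=0.2, size=n)

# Entangled representation: orthogonal mixing of the true latents.
L = ortho_group.rvs(m, random_state=0)
Z_ent = Z @ L.T

# Ridge is invariant to orthogonal mixing: identical predictions either way.
r_dis = Ridge(alpha=1.0, fit_intercept=False).fit(Z, y)
r_ent = Ridge(alpha=1.0, fit_intercept=False).fit(Z_ent, y)
print(np.max(np.abs(r_dis.predict(Z) - r_ent.predict(Z_ent))))  # ~0
```

Substituting w = Lᵀw̃ shows the two ridge problems are reparameterizations of each other with equal penalty, which is why Entangled-Ridge and Disentangled-Ridge coincide in Figure 1, while the L1 penalty breaks this invariance.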

Figure 2: Illustration of Assumption 6, showing three examples of distributions P_{W|S}. The red distribution satisfies the assumption, but the blue and orange distributions do not. The red lines are level sets of a Gaussian distribution with full-rank covariance. The blue line represents the support of a Gaussian distribution with a low-rank covariance. The orange dots represent a distribution with finite support. The green vector a shows that the condition is violated for both the blue and the orange distributions, since, in both cases, W_{1,S} a = 0 (orthogonality) with probability greater than zero.

Figure 3: The leftmost figure represents S, the support of some p(S). The other figures form a verification that Assumption 7 holds for S.

Figure 4: Disentanglement performance (MCC) for inner-Lasso, inner-Ridge and inner-Ridge combined with ICA as a function of the regularization parameter (left and middle), with varying levels of correlation between latents (top) and of noise on the latents (bottom). The right column shows the performance of the best hyperparameter for different values of correlation and noise. et al. (2022), theoretical benefits of disentanglement for transfer learning are not clearly established. Some works have investigated this question empirically and obtained both positive (van Steenkiste et al., 2019; Miladinović et al., 2019; Dittadi et al., 2021) and negative results (Locatello et al., 2019; Montero et al., 2021). Invariant risk minimization (Arjovsky et al., 2020; Ahuja et al., 2020; Krueger et al., 2021; Lu et al., 2021) aims at learning a representation that elicits a task-invariant base-predictor. This differs from our approach, which learns base-predictors that are task-specific.

Figure 5: Effect of sparsity on the percentage of tasks using specific features, with our meta-learning objective, on miniImageNet (left). The accuracy of the meta-learning algorithm and the average level of sparsity in the base-learners, as λ varies (right).

Table of Notation.

Ground-truth model:
W ∈ R^{k×m} — ground-truth coefficients
Ŵ ∈ R^{k×m} — learned coefficients
θ — ground-truth parameters of the representation
θ̂ — learned parameters of the representation
f_θ : R^d → R^m — ground-truth representation
f_{θ̂} : R^d → R^m — learned representation
η ∈ R^k — parameter of the distribution p(y; η)
P_W — distribution over ground-truth coefficient matrices W
S := {j ∈ [m] | W_{:j} ≠ 0} — support of W
P_{W|S} — conditional distribution of W given S
p(S) — ground-truth distribution over possible supports S
S — support of the distribution p(S)

Optimization:
W — primal variable
Λ — dual variable
h* : a ↦ sup_{b∈R^d} ⟨a, b⟩ − h(b) — Fenchel conjugate of the function h : R^d → R
f □ g : a ↦ min_b f(a − b) + g(b) — inf-convolution of the functions f and g
BST : (a, τ) ↦ (1 − τ/∥a∥)₊ a — block soft-thresholding operator



Figure 6: Prediction performance (R score) for inner-Lasso, inner-Ridge and inner-Ridge combined with ICA as a function of the regularization parameter (left and middle), with varying levels of correlation between latents (top) and noise on the latents (bottom). The right column shows the performance of the best hyperparameter for different values of correlation and noise levels.

Figure 7: Disentanglement (MCC, top) and prediction (R score, bottom) performance for inner-Lasso, inner-Ridge and inner-Ridge combined with ICA as a function of the regularization parameter. The metrics are plotted for multiple values of the block size of the support. Block size = 6 corresponds to no sparsity in the ground-truth coefficients.

Figure 13: Disentanglement performance (DCI) for inner-Lasso, inner-Ridge and inner-Ridge combined with ICA as a function of the regularization parameter (left and middle). The right column shows the performance of the best hyperparameter for different values of correlation and noise. The top row shows results for the disentanglement metric of DCI and the bottom row for the completeness metric of DCI.



which allows us to apply Theorem 3 to obtain the desired conclusion.

B.2 A DISTRIBUTION WITHOUT DENSITY SATISFYING ASSUMPTION 6

Interestingly, there are distributions over W_{1,S} | S that do not have a density w.r.t. the Lebesgue measure but still satisfy Assumption 6. This is the case, e.g., when W_{1,S} | S puts uniform mass over a (|S| − 1)-dimensional sphere embedded in R^{|S|} and centered at zero. In that case, for all a ∈ R^{|S|} \ {0}, the intersection of span{a}^⊥, which is (|S| − 1)-dimensional, with the (|S| − 1)-dimensional sphere is (|S| − 2)-dimensional and thus has probability zero of occurring. One can certainly construct more exotic examples of measures satisfying Assumption 6 that concentrate mass on lower-dimensional manifolds.
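The sphere example can be checked numerically. In this hedged sketch the dimension and the direction a are arbitrary choices of ours: sampling uniformly from the unit sphere (via normalized Gaussians) and testing orthogonality against a fixed nonzero vector essentially never produces a zero inner product.

```python
import numpy as np

rng = np.random.default_rng(0)
dim = 3                          # plays the role of |S|; arbitrary here
a = np.array([1.0, -2.0, 0.5])   # any fixed nonzero direction

# Uniform samples on the (dim - 1)-sphere via normalized Gaussians.
w = rng.normal(size=(100_000, dim))
w /= np.linalg.norm(w, axis=1, keepdims=True)

# Fraction of samples (numerically) orthogonal to a: zero, matching the
# claim that span{a}^⊥ intersects the sphere in a probability-zero set.
print(np.mean(np.abs(w @ a) < 1e-12))  # 0.0
```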

D.2.6 ADDITIONAL METRICS FOR DISENTANGLEMENT

We implemented metrics from the DCI framework (Eastwood & Williams, 2018) to evaluate disentanglement: 1) DCI-Disentanglement measures how many ground-truth latent components are related to a particular component of the learned latent representation; 2) DCI-Completeness measures how many learned latent components are related to a particular component of the ground-truth latent representation. Note that, for the definition of disentanglement used in the present work (Definition 1), we want both DCI-Disentanglement and DCI-Completeness to be high.

The DCI framework requires a matrix of relative importance. In our implementation, this matrix is obtained by performing linear regression with the learned latent representation f_θ̂(x) as inputs and the ground-truth latent representation f_θ(x) as targets; denote the solution by W. Further, define the importance matrix I := |W|, where I_{i,j} denotes the relevance of the inferred latent f_θ̂(x)_j for predicting the true latent f_θ(x)_i.

Figure 11: Varying one factor at a time in the image and showing how the learned representation varies in response. This representation was learned with inner-Lasso (best hyperparameter) on a dataset with correlation 0.9 between latents and a noise scale of 1. The corresponding MCC is 0.96. Qualitatively, the representation appears well disentangled, though not as well as in Figure 9 (reflected by a drop in MCC of 0.03).
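The importance-matrix construction and the two DCI scores can be sketched as follows. This is a minimal NumPy implementation assuming the entropy-based scores of Eastwood & Williams (2018); the weighting conventions shown are one common choice, not necessarily the exact ones used in our experiments:

```python
import numpy as np

def _entropy(p, base):
    """Shannon entropy of a probability vector, in the given base."""
    p = p[p > 0]
    return -(p * np.log(p)).sum() / np.log(base)

def dci_scores(z_true, z_learned, eps=1e-12):
    """DCI disentanglement and completeness from a linear-regression importance matrix."""
    # Regress ground-truth latents on learned ones: z_true ≈ z_learned @ W.
    W, *_ = np.linalg.lstsq(z_learned, z_true, rcond=None)
    I = np.abs(W).T  # I[i, j]: relevance of learned latent j for true latent i
    n_true, n_learned = I.shape

    # Disentanglement: each learned component should matter for few true latents.
    col = I / (I.sum(axis=0, keepdims=True) + eps)
    d_j = 1.0 - np.array([_entropy(col[:, j], n_true) for j in range(n_learned)])
    rho = I.sum(axis=0) / I.sum()  # weight codes by overall importance
    disentanglement = float(rho @ d_j)

    # Completeness: each true latent should be captured by few learned components.
    row = I / (I.sum(axis=1, keepdims=True) + eps)
    c_i = 1.0 - np.array([_entropy(row[i, :], n_learned) for i in range(n_true)])
    completeness = float(c_i.mean())
    return disentanglement, completeness
```

For a representation that recovers the true latents exactly (e.g. z_learned equal to z_true up to permutation and per-component scaling), the importance matrix is a scaled permutation and both scores equal 1.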

