SYNERGIES BETWEEN DISENTANGLEMENT AND SPARSITY: A MULTI-TASK LEARNING PERSPECTIVE

Abstract

Although disentangled representations are often said to be beneficial for downstream tasks, current empirical and theoretical understanding is limited. In this work, we provide evidence that disentangled representations coupled with sparse base-predictors improve generalization. In the context of multi-task learning, we prove a new identifiability result giving conditions under which maximally sparse base-predictors yield disentangled representations. Motivated by this theoretical result, we propose a practical approach to learn disentangled representations based on a sparsity-promoting bi-level optimization problem. Finally, we explore a meta-learning version of this algorithm based on group Lasso multiclass SVM base-predictors, for which we derive a tractable dual formulation. It obtains competitive results on standard few-shot classification benchmarks, while each task uses only a fraction of the learned representation.

1. INTRODUCTION

The recent literature on self-supervised learning has provided evidence that learning a representation on large corpora of data can yield strong performance on a wide variety of downstream tasks (Devlin et al., 2018; Chen et al., 2020), especially in few-shot learning scenarios where the training data for these tasks is limited (Brown et al., 2020b; Dosovitskiy et al., 2021; Radford et al., 2021). Beyond transferring across multiple tasks, these learned representations also lead to improved robustness against distribution shifts (Wortsman et al., 2022) as well as stunning text-conditioned image generation (Ramesh et al., 2022). However, preliminary assessments of the latter have highlighted shortcomings related to compositionality (Marcus et al., 2022), suggesting that new algorithmic innovations are needed to make further progress.

Another line of work has argued for the integration of ideas from causality to make progress towards more robust and transferable machine learning systems (Pearl, 2019; Schölkopf, 2019; Goyal & Bengio, 2022). Causal representation learning has emerged recently as a field aiming to define and learn representations suited for causal reasoning (Schölkopf et al., 2021). This set of ideas is strongly related to learning disentangled representations (Bengio et al., 2013). Informally, a representation is considered disentangled when its components are in one-to-one correspondence with natural and interpretable factors of variation, such as object positions, colors, or shapes. Although a plethora of works have investigated theoretically under which conditions disentanglement is possible (Hyvärinen & Morioka, 2016; 2017; Hyvärinen et al., 2019; Khemakhem et al., 2020a; Locatello et al., 2020a; Klindt et al., 2021; Von Kügelgen et al., 2021; Gresele et al., 2021; Lachapelle et al., 2022; Lippe et al., 2022b; Ahuja et al., 2022c), fewer works have tackled how a disentangled representation could be beneficial for downstream tasks.
Those who did mainly provide empirical rather than theoretical evidence for or against its usefulness (Locatello et al., 2019; van Steenkiste et al., 2019; Miladinović et al., 2019; Dittadi et al., 2021; Montero et al., 2021). In this work, we explore synergies between disentanglement and sparse base-predictors in the context of multi-task learning. At the heart of our contributions is the assumption that only a small subset of all the factors of variation are useful for each downstream task, and that this subset might change from one task to another. We refer to such tasks as sparse tasks, and to their corresponding sets of useful factors as their supports. This assumption was initially suggested by Bengio et al. (2013, Section 3.5): "the feature set being trained may be destined to be used in multiple tasks that may have distinct [and unknown] subsets of relevant features. Considerations such as these lead us to the conclusion that the most robust approach to feature learning is to disentangle as many factors as possible, discarding as little information about the data as is practical". This strategy is very much in line with the current self-supervised learning trend (Radford et al., 2021), except for its focus on disentanglement.

Our main contributions are the following: (i) We formalize this "sparse task assumption" and argue theoretically and empirically that, in this context, disentangled representations coupled with sparsity-regularized base-predictors can generalize better than their entangled counterparts (Section 2.1). (ii) We introduce a novel identifiability result (Theorem 1) which shows how one can leverage multiple sparse tasks to learn a shared disentangled representation by regularizing the task-specific predictors to be maximally sparse (Section 2.2.1). Crucially, Assumption 7 formalizes how diverse the task supports have to be in order to guarantee disentanglement.
(iii) Motivated by this result, we propose a tractable bi-level optimization problem (Problem (4)) to learn the shared representation while regularizing the task-specific base-predictors to be sparse (Section 2.2.2). We validate our theory by showing that our approach can indeed disentangle latent factors on tasks constructed from the 3D Shapes dataset (Burgess & Kim, 2018). (iv) Finally, we draw a connection between this bi-level optimization problem and formulations from the meta-learning literature (Section 2.3). Inspired by our identifiability result, we enhance an existing method (Lee et al., 2019) so that the base-learners are group-sparse SVMs. We show that this new meta-learning algorithm achieves competitive performance on the miniImageNet benchmark (Vinyals et al., 2016), while using only a fraction of the learned representation.

2. SYNERGIES BETWEEN DISENTANGLEMENT AND SPARSITY

In this section, we formally introduce the notion of entangled and disentangled representations. First, we assume the existence of some ground-truth encoder function f_θ : R^d → R^m that maps an observation x ∈ X ⊆ R^d, e.g., an image, to its corresponding interpretable and usually lower-dimensional representation f_θ(x) ∈ R^m, with m ≤ d. The exact form of this ground-truth encoder depends on the task at hand, but also on what the machine learning practitioner considers interpretable. The learned encoder function is denoted by f_θ̂ : R^d → R^m and should not be conflated with the ground-truth encoder f_θ. For example, f_θ̂ can be parametrized by a neural network. Throughout, we use the following definition of disentanglement.

Definition 1 (Disentangled Representation, Khemakhem et al. 2020a; Lachapelle et al. 2022). A learned encoder function f_θ̂ : R^d → R^m is said to be disentangled w.r.t. the ground-truth representation f_θ when there exists an invertible diagonal matrix D and a permutation matrix P such that, for all x ∈ X, f_θ̂(x) = DP f_θ(x). Otherwise, the encoder f_θ̂ is said to be entangled.

Intuitively, a representation is disentangled when there is a one-to-one correspondence between its components and the components of the ground-truth representation, up to rescaling. Note that there exist less stringent notions of disentanglement which allow for component-wise nonlinear invertible transformations of the factors (Hyvärinen & Morioka, 2017; Hyvärinen et al., 2019).

Notation. Capital bold letters denote matrices and lower-case bold letters denote vectors. The set of integers from 1 to n is denoted by [n]. We write ∥·∥ for the Euclidean norm on vectors and the Frobenius norm on matrices. For a matrix A ∈ R^{k×m}, ∥A∥_{2,1} = Σ_{j=1}^{m} ∥A_{:j}∥ and ∥A∥_{2,0} = Σ_{j=1}^{m} 1[∥A_{:j}∥ ≠ 0], where 1[·] is the indicator function. The ground-truth parameter of the encoder is θ, while that of the learned encoder is θ̂; we follow this hat convention for all parameters throughout. Table 1 in Appendix A summarizes all the notation.
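Definition 1 and the two group norms above can be made concrete with a small sketch. The code below is our own illustration, not from the paper; the function names (`group_norms`, `is_disentangled`) and the column-matching procedure are assumptions of this sketch, which tests the DP-relation on a finite sample of paired representations (rows are samples, columns are factors).

```python
import numpy as np

def group_norms(A):
    """Column-group norms used as sparsity regularizers:
    ||A||_{2,1} = sum_j ||A_{:j}||       (convex surrogate), and
    ||A||_{2,0} = #{j : ||A_{:j}|| != 0} (number of active columns)."""
    col_norms = np.linalg.norm(A, axis=0)
    return col_norms.sum(), int(np.count_nonzero(col_norms))

def is_disentangled(Z_learned, Z_true, tol=1e-8):
    """Finite-sample check of Definition 1.  Returns True iff each learned
    column is a nonzero scalar multiple of exactly one distinct ground-truth
    column, i.e. the representations differ by a permutation matrix P and an
    invertible diagonal matrix D."""
    m = Z_true.shape[1]
    matched = set()
    for j in range(Z_learned.shape[1]):
        hit = None
        for k in range(m):
            zt = Z_true[:, k]
            denom = zt @ zt
            if denom == 0.0:
                continue
            # Least-squares scalar fit of learned column j onto true column k.
            scale = (Z_learned[:, j] @ zt) / denom
            if abs(scale) > tol and np.allclose(Z_learned[:, j], scale * zt, atol=tol):
                hit = k
                break
        if hit is None or hit in matched:
            return False
        matched.add(hit)
    return True
```

For example, applying a permutation and a diagonal rescaling to a sample of ground-truth factors passes the check, while a generic invertible mixing fails it.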

2.1. DISENTANGLEMENT AND SPARSE BASE-PREDICTORS FOR IMPROVED GENERALIZATION

In this section, we compare the generalization performance of entangled and disentangled representations on sparse downstream tasks. We show that the maximum likelihood estimator (defined in Problem (1)) computed on linearly equivalent representations (entangled or disentangled) yields the same model (Proposition 1). However, disentangled representations have better generalization properties when combined with a sparse base-predictor (Proposition 2 and Figure 1). First, the learned representation f_θ̂ is assumed to be linearly equivalent to the ground-truth representation f_θ, i.e., there exists an invertible matrix L such that, for all x ∈ X, f_θ̂(x) = L f_θ(x).
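This setup admits a minimal numerical sketch, entirely ours and not the paper's code (the names `w_true`, `Z_ent`, etc. are assumptions): on a sparse task, unregularized least squares (the MLE under Gaussian noise) fitted on a disentangled representation and on a linearly entangled one Lf_θ(x) produces identical predictions, yet the optimal coefficient vector is sparse only in the disentangled basis.

```python
import numpy as np

rng = np.random.default_rng(1)
n, m = 200, 5
Z = rng.normal(size=(n, m))      # disentangled features f_theta(x), one row per sample
L = rng.normal(size=(m, m))      # generic invertible mixing matrix
Z_ent = Z @ L.T                  # entangled features: L f_theta(x)

w_true = np.array([3.0, 0.0, 0.0, 0.0, 0.0])  # sparse task: only factor 0 is useful
y = Z @ w_true

# Unregularized least squares fits y exactly on either representation, so both
# yield the same predictions ...
w_dis, *_ = np.linalg.lstsq(Z, y, rcond=None)
w_ent, *_ = np.linalg.lstsq(Z_ent, y, rcond=None)
assert np.allclose(Z @ w_dis, Z_ent @ w_ent)

# ... but the entangled solution w_ent = L^{-T} w_true is generically dense,
# so a sparsity-promoting penalty can only exploit the task's support when
# the representation is disentangled.
nnz_dis = int(np.sum(np.abs(w_dis) > 1e-8))
nnz_ent = int(np.sum(np.abs(w_ent) > 1e-8))
```

Here `nnz_dis` is 1 while `nnz_ent` equals m, illustrating why the sparsity regularizer of the following sections favors disentangled encoders.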

