ON LINEAR IDENTIFIABILITY OF LEARNED REPRESENTATIONS

Abstract

Identifiability is a desirable property of a statistical model: it implies that the true model parameters may be estimated to any desired precision, given sufficient computational resources and data. We study identifiability in the context of representation learning: discovering nonlinear data representations that are optimal with respect to some downstream task. When parameterized as deep neural networks, such representation functions lack identifiability in parameter space, because they are overparameterized by design. In this paper, building on recent advances in nonlinear Independent Components Analysis, we aim to rehabilitate identifiability by showing that a large family of discriminative models are in fact identifiable in function space, up to a linear indeterminacy. Many models for representation learning in a wide variety of domains, including models for text, images, and audio that were state-of-the-art at time of publication, turn out to be identifiable in this sense. We derive sufficient conditions for linear identifiability and provide empirical support for the result on both simulated and real-world data.

1. INTRODUCTION

An increasingly common methodology in machine learning is to improve performance on a primary downstream task by first learning a high-dimensional representation of the data on a related, proxy task. In this paradigm, training a model reduces to fine-tuning the learned representations for optimal performance on a particular sub-task (Erhan et al., 2010). Deep neural networks (DNNs), as flexible function approximators, have been surprisingly successful in discovering effective high-dimensional representations for use in downstream tasks such as image classification (Sharif Razavian et al., 2014), text generation (Radford et al., 2018; Devlin et al., 2018), and sequential decision making (Oord et al., 2018).

When learning representations for downstream tasks, it would be useful if the representations were reproducible: every time a network relearns the representation function on the same data distribution, the result should be approximately the same, regardless of small deviations in parameter initialization or the optimization procedure. In some applications, such as learning real-world causal relationships from data, reproducible learned representations are crucial for accurate and robust inference (Johansson et al., 2016; Louizos et al., 2017).

A rigorous way to achieve reproducibility is to choose a model whose representation function is identifiable in function space. Informally speaking, identifiability in function space is achieved when, in the limit of infinite data, there exists a single, global optimum in function space. Interestingly, Figure 1 exhibits learned representation functions that appear to be the same up to a linear transformation, even on finite data and optimized without convergence guarantees (see Appendix A.1 for training details). In this paper, we account for Figure 1 by making precise the relationship it exemplifies.
We prove that a large class of discriminative and autoregressive models are identifiable in function space, up to a linear transformation. Our results extend recent advances in the theory of nonlinear Independent Components Analysis (ICA), which have provided strong identifiability results for generative models of data (Hyvärinen et al., 2018; Khemakhem et al., 2019; 2020; Sorrenson et al., 2020). Our key contribution is to bridge the gap between these results and discriminative models commonly used for representation learning (e.g., Hénaff et al., 2019; Brown et al., 2020).

The rest of the paper is organized as follows. In Section 2, we describe a general discriminative model family, defined by its canonical mathematical form, which generalizes many supervised, self-supervised, and contrastive learning frameworks. In Section 3, we prove that learned representations in this family have an asymptotic property desirable for representation learning: equality up to a linear transformation. In Section 4, we show that this family includes a number of highly performant models, state-of-the-art at publication for their problem domains, including CPC (Oord et al., 2018), BERT (Devlin et al., 2018), and GPT-2 and GPT-3 (Radford et al., 2018; 2019; Brown et al., 2020). Section 5 investigates the practically realizable regime of finite data and partial optimization, showing that representations learned by members of the identifiable model family approach equality up to a linear transformation as a function of dataset size, neural network capacity, and optimization progress.

2. MODEL FAMILY AND DATA DISTRIBUTION

The learned embeddings of a DNN are a function not only of the parameters, but also of the network architecture and the size of the dataset (viewed as a sample from the underlying data distribution). This renders any analysis in full generality challenging. To make such an analysis tractable, in this section we begin by specifying a set of assumptions about the underlying data distribution and model family that must hold for the learned representations to be similar up to a linear transformation. These assumptions are, in fact, satisfied by a number of already published, highly performant models. We establish definitions in this section, and discuss these existing approaches in depth in Section 4.

Data Distribution

We assume the existence of a generalized dataset in the form of an empirical distribution p_D(x, y, S) over random variables x, y, and S with the following properties:

• The random variable x is an input variable, typically high-dimensional, such as text or an image.
• The random variable y is a target variable whose value the model predicts. In the case of object classification, this would be some semantically meaningful class label. However, in our model family, y may also be a high-dimensional context variable, such as a text, image, or sentence fragment.
• S is a set containing the possible values of y given x, so that p_D(y|x, S) > 0 ⇐⇒ y ∈ S.

Note that the set of labels S is not fixed, but a random variable. This allows supervised, contrastive, and self-supervised learning frameworks to be analyzed together: the meaning of S encodes the task. For supervised classification, S is deterministic and contains class labels. For self-supervised pretraining, S contains randomly-sampled high-dimensional variables such as image embeddings. For deep metric learning (Hoffer and Ailon, 2015; Sohn, 2016), the set S contains one positive and k negative samples of the class to which x belongs.

Canonical Discriminative Form

Given a data distribution as above, a generalized discriminative model family may be defined by its parameterization of the probability of a target variable y conditioned on an observed variable x and a set S that contains not only the true target label y, but
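A common instantiation of such a conditional model, used for example in contrastive learning, scores each candidate in S by an inner product between an input embedding and a label embedding, then normalizes with a softmax over S. The following is a minimal numpy sketch of that pattern, not the paper's implementation; the function name and argument conventions are ours for illustration:

```python
import numpy as np

def candidate_log_prob(f_x, g_S, target_idx):
    """Log-probability of one candidate under a softmax over the set S,
    where each candidate y in S is scored by the inner product f(x) . g(y).

    f_x:        (d,) embedding of the input x
    g_S:        (k, d) embeddings of the k candidate labels in S
    target_idx: index of the true label y within S
    """
    logits = g_S @ f_x                    # similarity of x to each candidate in S
    logits = logits - logits.max()        # shift for numerical stability
    log_Z = np.log(np.exp(logits).sum())  # log-normalizer over the set S
    return logits[target_idx] - log_Z
```

Because the normalization runs only over S (which is itself a random variable), the same function covers supervised classification (S fixed to the label set) and contrastive setups (S a sampled positive plus negatives).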



Figure 1: Left and Middle: Two learned DNN representation functions f θ1 (B), f θ2 (B) visualized on held-out data B. The DNNs are word embedding models (Mnih and Teh, 2012) trained on the Billion Word Dataset (Chelba et al., 2013) (see Appendix A.1 for code release and training details). Right: Af θ1 (B) and f θ2 (B), where A is a linear transformation learned after training. The overlap exhibits linear identifiability (see Section 3): different representation functions, learned on the same data distribution, differ only by a linear transformation in function space.
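A linear map like the A in Figure 1 can be estimated by ordinary least squares on held-out representations. The sketch below is ours, not the paper's code: it assumes the two representations are stacked row-wise into (n, d) matrices, fits A by right-multiplication (equivalent, up to transposition, to applying a linear map to each embedding vector), and reports a relative residual that is near zero when the two functions agree up to a linear transformation:

```python
import numpy as np

def linear_alignment_error(Z1, Z2):
    """Fit A minimizing ||Z1 @ A - Z2||_F and return the relative residual.

    Z1, Z2: (n, d) matrices of representations of the same held-out batch,
            one row per data point, produced by two separately trained models.
    A small return value indicates the representations are linearly aligned.
    """
    A, *_ = np.linalg.lstsq(Z1, Z2, rcond=None)  # least-squares solve for A
    return np.linalg.norm(Z1 @ A - Z2) / np.linalg.norm(Z2)
```

A diagnostic like this makes the visual overlap in Figure 1 quantitative: exact linear relationships yield residuals at machine precision, while unrelated representations yield residuals near one.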

