THE ROLE OF DISENTANGLEMENT IN GENERALISATION

Abstract

Combinatorial generalisation, the ability to understand and produce novel combinations of familiar elements, is a core capacity of human intelligence that current AI systems struggle with. Recently, it has been suggested that learning disentangled representations may help address this problem. The claim is that such representations capture the compositional structure of the world in separate components that can be recombined to support combinatorial generalisation. In this study, we systematically tested how the degree of disentanglement affects various forms of generalisation, including two forms of combinatorial generalisation that varied in difficulty. We trained three classes of variational autoencoders (VAEs) on two datasets in an unsupervised reconstruction task, excluding certain combinations of generative factors during training. At test time, we asked the models to reconstruct the missing combinations in order to measure generalisation performance. Irrespective of the degree of disentanglement, we found that the models supported only weak combinatorial generalisation. We obtained the same outcome when we directly input perfectly disentangled representations as the latents, and when we tested a model on a more complex task that explicitly required independent generative factors to be controlled. While learning disentangled representations does improve interpretability and sample efficiency in some downstream tasks, our results suggest that it is not sufficient to support more difficult forms of generalisation.

1. INTRODUCTION

Generalisation to unseen data has been a key challenge for neural networks since the early days of connectionism, with considerable debate about whether these models can emulate the kinds of behaviours that are present in humans (McClelland et al., 1986; Fodor & Pylyshyn, 1988; Smolensky, 1987; 1988; Fodor & McLaughlin, 1990). While the modern successes of deep learning do indeed point to impressive gains in this regard, human-level generalisation still remains elusive (Lake & Baroni, 2018; Marcus, 2018). One explanation for this is that humans encode stimuli in a compositional manner, with a small set of independent and more primitive features (e.g., separate representations of size, position, line orientation, etc.) being used to build more complex representations (e.g., a square of a given size and position). The meaning of the more complex representation comes from the meaning of its parts. Critically, compositional representations afford the ability to recombine primitives in novel ways: if a person has learnt to recognise squares and circles in a context where all squares are blue and all circles are red, they can nevertheless also recognise red squares, even though these never appeared in the training data. This ability to perform combinatorial generalisation based on compositional representations is thought to be a hallmark of human-level intelligence (Fodor & Pylyshyn, 1988) (see McClelland et al. (1986) for a diverging opinion). Recently, it has been proposed that generalisation in neural networks can be improved by extracting disentangled representations (Higgins et al., 2017) from data using (variational) generative models (Kingma & Welling, 2013; Rezende et al., 2014).
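A common way to impose a disentanglement pressure in such variational models is to scale the KL term of the negative ELBO by a coefficient β, with β = 1 recovering the standard VAE and β > 1 encouraging factorised latents (the β-VAE of Higgins et al., 2017). The following is a minimal numpy sketch of that objective, not the authors' implementation; it assumes Bernoulli pixel likelihoods and a diagonal Gaussian posterior:

```python
import numpy as np

def kl_diag_gaussian(mu, log_var):
    """KL(q(z|x) || N(0, I)) for a diagonal Gaussian posterior, per sample."""
    return 0.5 * np.sum(np.exp(log_var) + mu**2 - 1.0 - log_var, axis=-1)

def beta_vae_loss(x, x_recon, mu, log_var, beta=4.0):
    """Negative ELBO with the KL term scaled by beta.

    beta=1 recovers a standard VAE; beta>1 adds pressure towards
    disentangled latents at some cost in reconstruction quality.
    """
    eps = 1e-7
    x_recon = np.clip(x_recon, eps, 1.0 - eps)
    # Bernoulli reconstruction negative log-likelihood, summed over pixels.
    recon_nll = -np.sum(x * np.log(x_recon) + (1 - x) * np.log(1 - x_recon),
                        axis=-1)
    return np.mean(recon_nll + beta * kl_diag_gaussian(mu, log_var))
```

FactorVAE replaces the uniform β weighting with a penalty on the total correlation of the aggregate posterior, but the intent, trading reconstruction fidelity for statistically independent latent dimensions, is the same.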
In this view, disentangled representations capture the compositional structure of the world (Higgins et al., 2018a; Duan et al., 2020), separating the generative factors present in the stimuli into separate components of the internal representation (Higgins et al., 2017; Burgess et al., 2018). It has been argued that these representations allow downstream models to perform better due to their structured nature (Higgins et al., 2017; 2018b) and to share information across related tasks (Bengio et al., 2014). Here we are interested in the question of whether networks can support combinatorial generalisation and extrapolation by exploiting these disentangled representations. In this study we systematically tested whether and how disentangled representations support three forms of generalisation: two forms of combinatorial generalisation that varied in difficulty, as well as extrapolation, as detailed below. We explored this issue by assessing how well models could render images when we varied (1) the image datasets (dSprites and 3DShape), (2) the models used to reconstruct these images (β-VAEs and FactorVAEs with different disentanglement pressures, and decoder models in which we dropped the encoders and directly input perfectly disentangled latents), and (3) the tasks, which varied in their combinatorial requirements (image reconstruction vs. image transformation). Across all conditions we found that the models supported only the simplest versions of combinatorial generalisation, and that the degree of disentanglement had no impact on the degree of generalisation. These findings suggest that models with entangled and disentangled representations both generalise on the basis of the overall similarity between training and test images (interpolation), and that combinatorial generalisation requires more than learning disentangled representations.
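The held-out-combination design described above can be sketched as follows. The factor names and grid sizes here are illustrative placeholders loosely modelled on dSprites, not the actual dataset specification:

```python
import itertools

# Hypothetical generative-factor grid (names and sizes are illustrative).
factors = {
    "shape": [0, 1, 2],        # e.g. square, ellipse, heart
    "scale": [0, 1, 2, 3],
    "pos_x": list(range(8)),
}

all_combos = list(itertools.product(*factors.values()))

def split_combinations(combos, held_out):
    """Remove held-out factor combinations from training.

    The held-out combinations are reconstructed only at test time,
    probing combinatorial generalisation rather than memorisation.
    """
    held = set(held_out)
    train = [c for c in combos if c not in held]
    test = [c for c in combos if c in held]
    return train, test

# Example exclusion: never show shape 2 at the largest scale in training.
held_out = [c for c in all_combos if c[0] == 2 and c[1] == 3]
train_combos, test_combos = split_combinations(all_combos, held_out)
```

Each remaining combination indexes an image in the dataset; a model trained only on `train_combos` is then asked to reconstruct the images indexed by `test_combos`.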

1.1. PREVIOUS WORK

Recent work on learning disentangled representations in unsupervised generative models has indeed shown some promise in improving the performance of downstream tasks (Higgins et al., 2018b; van Steenkiste et al., 2019), but this benefit is mainly related to sample efficiency rather than generalisation. Indeed, we are aware of only two studies that have considered the importance of learned disentanglement for combinatorial generalisation; they used different network architectures and reached opposite conclusions. Bowers et al. (2016) showed that a recurrent model of short-term memory tested on lists of words that required some degree of combinatorial generalisation (recalling a sequence of words when one or more of the words at test were novel) only succeeded when it had learned highly selective (disentangled) representations ("grandmother cell" units for letters). By contrast, Chaabouni et al. (2020) found that models with disentangled representations do not confer significant improvements in generalisation over entangled ones in a language modelling setting, with both entangled and disentangled representations supporting combinatorial generalisation as long as the training set was rich enough. At the same time, they found that languages generated through compositional representations were easier to learn, suggesting this as a pressure to learn disentangled representations. A number of recent papers have reported that VAEs can support some degree of combinatorial generalisation, but there is no clear understanding of whether and how disentangled representations played any role in supporting this performance. Esmaeili et al. (2019) showed that a model trained on the MNIST dataset could reconstruct images even when particular combinations of factors were removed during training, such as a thick number 7 or a narrow 0.
The authors also showed that the model had learned disentangled representations and concluded that these representations played a role in the successful performance. However, the authors did not vary the degree of disentanglement in their models and, accordingly, it is possible that a VAE that learned entangled representations would do just as well. Similarly, Higgins et al. (2018c) have highlighted how VAEs that learn disentangled representations can support some forms of combinatorial generalisation when generating images from text. For example, their model could render a room with white walls, a pink floor and a blue ceiling even though it was never shown that combination in the training set. This is an impressive form of combinatorial generalisation but, as we show below, truly compositional representations should be able to support several other forms of combinatorial generalisation that were not tested in this study. Moreover, it is not clear what role disentanglement played in this successful instance of generalisation. Finally, Zhao et al. (2018) assessed VAE performance on a range of combinatorial generalisation tasks that varied in difficulty, and found that the model performed well in the simplest settings but struggled in more difficult ones. But again, they did not consider whether learning disentangled representations was relevant to generalisation performance. Another work closely related to ours is Locatello et al. (2019), who examined how hard it is to learn disentangled representations and their relation to sample efficiency in downstream tasks. We are interested in a related, but different question: even if a model learns a disentangled representation in an intermediate layer, does this enable it to achieve combinatorial generalisation?

