CURVED DATA REPRESENTATIONS IN DEEP LEARNING

Abstract

The phenomenal success of deep neural networks has inspired many efforts to understand the inner mechanisms of these models. To this end, several works have studied geometric properties, such as the intrinsic dimension, of the latent data representations produced by the layers of the network. In this paper, we investigate the curvature of data manifolds, i.e., the deviation of the manifold from being flat in its principal directions. We find that state-of-the-art trained convolutional neural networks have a characteristic curvature profile along their layers: an initial increase, followed by a long plateau, and tailed by another increase. In contrast, untrained networks exhibit qualitatively and quantitatively different curvature profiles. We also show that the curvature gap between the last two layers is strongly correlated with the performance of the network. Further, we find that the intrinsic dimension of the latent data along the network layers is not necessarily indicative of curvature. Finally, we evaluate the effect of common regularizers such as weight decay and mixup on curvature, and we find that mixup-based methods flatten intermediate layers, whereas the final layers still feature high curvatures. Our results indicate that relatively flat manifolds which transform into highly curved manifolds toward the last layers generalize well to unseen data.

1. INTRODUCTION

Real-world data arising from scientific and engineering problems is often high-dimensional and complex. Using such data for downstream tasks may seem hopeless at first glance. Nevertheless, the widely accepted manifold hypothesis (Cayton, 2005), which states that complex high-dimensional data is intrinsically low-dimensional, suggests that not all hope is lost. Indeed, significant efforts in machine learning have been dedicated to developing tools for extracting meaningful low-dimensional features from real-world information (Khalid et al., 2014; Bengio et al., 2013). Deep learning approaches, which manipulate data via nonlinear neural networks, have been particularly successful in challenging tasks such as classification (Krizhevsky et al., 2017) and recognition (Girshick et al., 2014). Unfortunately, the inner mechanisms of deep models are, by and large, not well understood.

Motivated by the manifold hypothesis and, more generally, by manifold learning (Belkin & Niyogi, 2003), several recent approaches have proposed to analyze deep models via their latent representations. Essentially, a manifold is a topological space that is locally similar to a Euclidean domain at each of its points (Lee, 2013). A key property of a manifold is its intrinsic dimension, defined as the dimension of the related Euclidean domain. Recent studies estimated the intrinsic dimension (ID) along the layers of trained neural networks using neighborhood information (Ansuini et al., 2019) and topological data analysis (Birdal et al., 2021). Remarkably, it has been shown that the ID admits a characteristic "hunchback" profile (Ansuini et al., 2019), i.e., it increases in the first layers and then decreases progressively. Moreover, the ID was found to be strongly correlated with the network's performance.

Still, the intrinsic dimension is only a single measure, providing limited knowledge of the manifold. To consider other properties, the manifold has to be equipped with additional structure. In this work, we focus on Riemannian manifolds, i.e., differentiable manifolds equipped with a smoothly varying inner product (Lee, 2006). Riemannian manifolds can be described using properties such as angles, distances, and curvatures. For instance, the curvature in two dimensions is the amount by which a surface deviates from being a plane, which is completely flat. Ansuini et al. (2019) conjectured that while the intrinsic dimension decreases with network depth, the underlying manifold is highly curved. Our study confirms the latter conjecture empirically by estimating the principal curvatures of latent representations of popular deep convolutional classification models trained on benchmark datasets.
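
To make the neighborhood-based ID estimation mentioned above concrete, the following is a minimal sketch of a TwoNN-style maximum-likelihood estimator, which infers the dimension from the ratio of each point's second- to first-nearest-neighbor distances. The function name, the reliance on scikit-learn, and the 10% discard fraction are our own illustrative choices, not details taken from the cited works.

```python
import numpy as np
from sklearn.neighbors import NearestNeighbors

def twonn_id(X, discard_fraction=0.1):
    """TwoNN-style intrinsic dimension estimate for representations X of shape (N, D)."""
    # Distances to the two nearest neighbors; column 0 is the point itself.
    dists, _ = NearestNeighbors(n_neighbors=3).fit(X).kneighbors(X)
    r1, r2 = dists[:, 1], dists[:, 2]  # assumes no duplicate points (r1 > 0)
    mu = r2 / r1
    # Drop the largest ratios, which are most affected by density variations.
    mu = np.sort(mu)[: int(len(mu) * (1.0 - discard_fraction))]
    # Under a locally uniform density, mu is Pareto-distributed with
    # F(mu) = 1 - mu^(-d); the maximum-likelihood estimate of d is:
    return len(mu) / np.sum(np.log(mu))
```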

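Principal curvatures can likewise be estimated from samples by fitting a quadratic in a local tangent frame: the eigenvalues of the fitted Hessian (the second fundamental form) are the principal curvatures. The sketch below is a simplified illustration rather than the estimator used in this paper: it assumes the intrinsic dimension d is known, considers only a single normal direction for brevity, and requires k to exceed d(d+1)/2; the function name and neighborhood size are hypothetical.

```python
import numpy as np

def principal_curvatures(X, idx, d, k=50):
    """Rough principal-curvature estimate for a d-dimensional manifold
    sampled by X, evaluated at the point X[idx] from its k nearest neighbors."""
    # Center the cloud at the query point and keep its k nearest neighbors.
    diffs = X - X[idx]
    nbrs = diffs[np.argsort(np.linalg.norm(diffs, axis=1))[:k]]
    # Local PCA: the top d right-singular vectors span the tangent space;
    # the (d+1)-th serves as a single normal direction (a simplification).
    _, _, Vt = np.linalg.svd(nbrs - nbrs.mean(axis=0), full_matrices=False)
    t = nbrs @ Vt[:d].T  # tangent coordinates, shape (k, d)
    n = nbrs @ Vt[d]     # height along the normal, shape (k,)
    # Fit n ~ 0.5 * t^T H t by least squares over the monomials t_i * t_j.
    iu = np.triu_indices(d)
    A = 0.5 * t[:, iu[0]] * t[:, iu[1]] * np.where(iu[0] == iu[1], 1.0, 2.0)
    h = np.linalg.lstsq(A, n, rcond=None)[0]
    # Assemble the symmetric Hessian (second fundamental form) and diagonalize.
    H = np.zeros((d, d))
    H[iu] = h
    H = H + H.T - np.diag(np.diag(H))
    return np.linalg.eigvalsh(H)  # principal curvatures (sign depends on normal)
```

As a sanity check, applying this sketch to points sampled from a unit sphere in R^3 with d = 2 should return two values close to ±1, the known principal curvatures of the sphere, with the sign determined by the orientation of the estimated normal.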
