FEW-SHOT LEARNING VIA LEARNING THE REPRESENTATION, PROVABLY

Abstract

This paper studies few-shot learning via representation learning, where one uses $T$ source tasks with $n_1$ data per task to learn a representation in order to reduce the sample complexity of a target task for which there are only $n_2 (\ll n_1)$ data. Specifically, we focus on the setting where there exists a good common representation between the source and target tasks, and our goal is to understand how large a sample size reduction is possible. First, we study the setting where this common representation is low-dimensional, and provide a risk bound of $\tilde{O}(\frac{dk}{n_1 T} + \frac{k}{n_2})$ on the target task for the linear representation class; here $d$ is the ambient input dimension and $k (\ll d)$ is the dimension of the representation. This result bypasses the $\Omega(\frac{1}{T})$ barrier under the i.i.d. task assumption, and captures the desired property that all $n_1 T$ samples from the source tasks can be pooled together for representation learning. We further extend this result to handle a general representation function class and obtain a similar result. Next, we consider the setting where the common representation may be high-dimensional but is capacity-constrained (say, in norm); here, we again demonstrate the advantage of representation learning in both high-dimensional linear regression and neural networks, and show that representation learning can fully utilize all $n_1 T$ samples from the source tasks.

* Alphabetical order.
$^1$ We only focus on the dependence on $T$, $n_1$, and $n_2$ in this paragraph. Note that Maurer et al. (2016) only considered $n_1 = n_2$, but their approach does not give a better result even if $n_1 > n_2$.

1. INTRODUCTION

A popular scheme for few-shot learning, i.e., learning in a data-scarce environment, is representation learning, where one first learns a feature extractor, or representation (e.g., the last layer of a convolutional neural network), from different but related source tasks, and then uses a simple predictor (usually a linear function) on top of this representation in the target task. The hope is that the learned representation captures the common structure across tasks, making a linear predictor sufficient for the target task. If the learned representation is good enough, a few samples may suffice for learning the target task, far fewer than the number required to learn it from scratch. While representation learning has achieved tremendous success in a variety of applications (Bengio et al., 2013), its theoretical study is limited. In existing theoretical work, the most natural algorithm is to explicitly look for the representation that, when combined with a (different) linear predictor on top for each task, achieves the smallest cumulative training error on the source tasks. Of course, the representation found is not guaranteed to be useful for the target task unless one makes assumptions that characterize the connections between different tasks. Existing work often imposes a probabilistic assumption on this connection: each task is sampled i.i.d. from an underlying distribution. Under this assumption, Maurer et al. (2016) showed an $\tilde{O}(\frac{1}{\sqrt{T}} + \frac{1}{\sqrt{n_2}})$ risk bound on the target task, where $T$ is the number of source tasks, $n_1$ is the number of samples per source task, and $n_2$ is the number of samples from the target task.$^1$ Unsatisfactorily, this bound necessarily requires the number of tasks $T$ to be large, and it does not improve as the number of samples per source task, $n_1$, increases.
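To make this two-phase scheme concrete in the linear case, the following sketch (our own illustrative code, not the paper's algorithm; all names are hypothetical) estimates each source task's predictor separately, recovers a shared $k$-dimensional subspace from the top singular directions of the stacked estimates, and then fits only a $k$-dimensional head on the target task. It assumes $n_1 \ge d$ so that each per-task predictor is identifiable; the analysis in this paper covers the joint ERM formulation more generally.

```python
import numpy as np

def learn_linear_representation(Xs, ys, k):
    """Estimate a shared k-dimensional linear representation from T source tasks.

    Xs: list of (n1, d) design matrices; ys: list of (n1,) label vectors.
    Assumes n1 >= d so each per-task predictor is identifiable.
    Returns B_hat, a (d, k) matrix with orthonormal columns.
    """
    # Per-task least-squares predictors, stacked as columns of a d x T matrix.
    Theta = np.column_stack(
        [np.linalg.lstsq(X, y, rcond=None)[0] for X, y in zip(Xs, ys)]
    )
    # Top-k left singular vectors span the estimated common subspace.
    U, _, _ = np.linalg.svd(Theta, full_matrices=False)
    return U[:, :k]

def fit_target(B_hat, X2, y2):
    """Few-shot phase: fit only a k-dim linear head on the frozen representation."""
    w = np.linalg.lstsq(X2 @ B_hat, y2, rcond=None)[0]
    return B_hat @ w  # effective d-dimensional predictor for the target task
```

In the noiseless well-specified case this recovers the target predictor from only $n_2 \ge k$ target samples, whereas learning the target task directly would require on the order of $d$ samples.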
Intuitively, one should expect more data to help, and therefore an ideal bound would be $\frac{1}{\sqrt{n_1 T}} + \frac{1}{\sqrt{n_2}}$ (or $\frac{1}{n_1 T} + \frac{1}{n_2}$ in the realizable case), because $n_1 T$ is the total number of training data points from the source tasks, which can potentially be pooled to learn the representation. Unfortunately, as pointed out by Maurer et al. (2016), there exists an example satisfying the i.i.d. task assumption for which $\Omega(\frac{1}{\sqrt{T}})$ is unavoidable (or $\Omega(\frac{1}{T})$ in the realizable setting). This means that the i.i.d. assumption alone is not sufficient if we want to take advantage of a large number of samples per task. Therefore, a natural question is: What connections between tasks enable representation learning to utilize all source data?

In this paper, we obtain the first set of results that fully utilize the $n_1 T$ data from the source tasks. We replace the i.i.d. assumption over tasks with natural structural conditions on the input distributions and linear predictors. These conditions capture the sense in which the target task is "covered" by the source tasks, which in turn gives rise to the desired guarantees. First, we study the setting where there exists a common well-specified low-dimensional representation in the source and target tasks, and obtain an $\tilde{O}(\frac{dk}{n_1 T} + \frac{k}{n_2})$ risk bound on the target task, where $d$ is the ambient input dimension, $k (\ll d)$ is the dimension of the representation, and $n_2$ is the number of data from the target task. Note that this improves over the $\frac{d}{n_2}$ rate of learning the target task directly without representation learning. The term $\frac{dk}{n_1 T}$ indicates that we can fully exploit all $n_1 T$ data in the source tasks to learn the representation. We further extend this result to a general representation function class and obtain an $\tilde{O}(\frac{C(\Phi)}{n_1 T} + \frac{k}{n_2})$ risk bound on the target task, where $\Phi$ is the representation function class and $C(\Phi)$ is a certain complexity measure of $\Phi$.
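For readers who want the estimator in symbols, the low-dimensional linear result above can be read as a guarantee for the natural two-phase ERM (a sketch of the standard formulation; the exact normalization and assumptions are those of the formal statement in Section 4):

```latex
% Phase 1: pooled ERM over a shared representation and per-task heads
\hat{B}, \{\hat{w}_t\}_{t=1}^{T}
  \in \operatorname*{arg\,min}_{B \in \mathbb{R}^{d \times k},\; w_1, \dots, w_T \in \mathbb{R}^{k}}
  \frac{1}{2 n_1 T} \sum_{t=1}^{T} \sum_{i=1}^{n_1}
  \bigl( y_{t,i} - x_{t,i}^{\top} B w_t \bigr)^2

% Phase 2: few-shot ERM on the target task with the representation frozen
\hat{w} \in \operatorname*{arg\,min}_{w \in \mathbb{R}^{k}}
  \frac{1}{2 n_2} \sum_{i=1}^{n_2}
  \bigl( y_i - x_i^{\top} \hat{B} w \bigr)^2
```

The key point is that Phase 1 pools all $n_1 T$ source samples to estimate the $dk$ parameters of $B$, while Phase 2 only needs to estimate $k$ parameters from the $n_2$ target samples.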
Second, we study the setting where there exists a common high-dimensional linear representation for the source and target tasks, and obtain an $\tilde{O}\left(\frac{R\sqrt{\mathrm{Tr}(\Sigma)}}{\sqrt{n_1 T}} + \frac{R\sqrt{\|\Sigma\|_2}}{\sqrt{n_2}}\right)$ rate, where $R$ is a normalized nuclear norm control over the linear predictors and $\Sigma$ is the covariance matrix of the raw feature. This also improves over the baseline rate for the case without representation learning. We further extend this result to two-layer neural networks with ReLU activation. Again, our results indicate that we can fully exploit the $n_1 T$ source data. A technical insight coming out of our analysis is that any capacity-controlled method that achieves low test error on the source tasks must also achieve low test error on the target task, by virtue of being forced to learn a good representation. Our result on high-dimensional representations shows that the capacity control for representation learning does not have to be through explicit low dimensionality.

Organization. The rest of the paper is organized as follows. We review related work in Section 2. In Section 3, we formally describe the setting we consider. In Section 4, we present our main result for low-dimensional linear representation learning. A generalization to nonlinear representation classes is given in Section 5. In Section 6, we present our main result for high-dimensional linear representation learning. In Section 7, we present our result for representation learning in neural networks. We conclude in Section 8 and defer most proofs to the appendices.
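As one way to see capacity-constrained (rather than explicitly low-dimensional) representation learning in action, the sketch below (our own illustration; the function names and the choice of proximal gradient descent are ours, not the paper's method) fits all source predictors jointly under a nuclear-norm penalty, which encourages them to share a representation without fixing its dimension in advance.

```python
import numpy as np

def svt(M, tau):
    """Singular-value soft-thresholding: the prox operator of tau * nuclear norm."""
    U, s, Vt = np.linalg.svd(M, full_matrices=False)
    return U @ np.diag(np.maximum(s - tau, 0.0)) @ Vt

def multitask_nuclear(Xs, ys, lam, step=None, n_iters=500):
    """Proximal gradient descent for
        min_Theta  sum_t (1/(2 n_t)) ||X_t theta_t - y_t||^2 + lam * ||Theta||_*
    where the columns of Theta (d x T) are the per-task linear predictors."""
    d, T = Xs[0].shape[1], len(Xs)
    if step is None:
        # Step size 1/L from the largest per-task smoothness constant.
        step = 1.0 / max(np.linalg.norm(X, 2) ** 2 / X.shape[0] for X in Xs)
    Theta = np.zeros((d, T))
    for _ in range(n_iters):
        # Gradient of the smooth part, one column per task.
        G = np.column_stack([X.T @ (X @ Theta[:, t] - y) / X.shape[0]
                             for t, (X, y) in enumerate(zip(Xs, ys))])
        Theta = svt(Theta - step * G, step * lam)
    return Theta
```

With a small penalty `lam`, the recovered predictor matrix is approximately low-rank: the norm constraint, not an explicit dimension $k$, forces the tasks onto a shared representation.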

2. RELATED WORK

The idea of multitask representation learning dates back at least to Caruana (1997); Thrun and Pratt (1998); Baxter (2000). Empirically, representation learning has shown great power in various domains; see Bengio et al. (2013) for a survey. In particular, representation learning is widely adopted for few-shot learning tasks (Sun et al., 2017; Goyal et al., 2019). Representation learning is also closely connected to meta-learning (Schaul and Schmidhuber, 2010). Recent work by Raghu et al. (2019) empirically suggested that the effectiveness of the popular meta-learning algorithm Model-Agnostic Meta-Learning (MAML) is due to its ability to learn a useful representation. The scheme we analyze in this paper is closely related to Lee et al. (2019); Bertinetto et al. (2018) for meta-learning. On the theoretical side, Baxter (2000) performed the first theoretical analysis and gave sample complexity bounds using covering numbers. Maurer et al. (2016) and follow-up work analyzed the benefit of representation learning for reducing the sample complexity of the target task. They assumed every task is drawn i.i.d. from an underlying distribution and obtained an $\tilde{O}(\frac{1}{\sqrt{T}} + \frac{1}{\sqrt{n_2}})$

