META-LEARNING TRANSFERABLE REPRESENTATIONS WITH A SINGLE TARGET DOMAIN

Abstract

Recent works found that fine-tuning and joint training, two popular approaches for transfer learning, do not always improve accuracy on downstream tasks. First, we aim to understand when and why fine-tuning and joint training can be suboptimal or even harmful for transfer learning. We design semi-synthetic datasets where the source task can be solved by either source-specific features or transferable features. We observe that (1) pre-training may have no incentive to learn transferable features and (2) joint training may simultaneously learn source-specific features and overfit to the target. Second, to improve over fine-tuning and joint training, we propose Meta Representation Learning (MeRLin) to learn transferable features. MeRLin meta-learns representations by ensuring that a head fit on top of the representations with target training data also performs well on target validation data. We also prove that MeRLin recovers the target ground-truth model with a quadratic neural net parameterization and a source distribution that contains both transferable and source-specific features. On the same distribution, pre-training and joint training provably fail to learn transferable features. MeRLin empirically outperforms previous state-of-the-art transfer learning algorithms on various real-world vision and NLP transfer learning benchmarks.

1. INTRODUCTION

Transfer learning, i.e., transferring knowledge learned from a large-scale source dataset to a small target dataset, is an important paradigm in machine learning (Yosinski et al., 2014) with wide applications in vision (Donahue et al., 2014) and natural language processing (NLP) (Howard & Ruder, 2018; Devlin et al., 2019). Because the source and target tasks are often related, we expect to be able to learn features from the source data that transfer to the target task. These features may help learn the target task with fewer examples (Long et al., 2015; Tamkin et al., 2020). The mainstream approaches to transfer learning are fine-tuning and joint training. Fine-tuning initializes from a model pre-trained on a large-scale source task (e.g., ImageNet) and continues training on the target task with a potentially different set of labels, e.g., object recognition (Wang et al., 2017; Yang et al., 2018; Kolesnikov et al., 2019), object detection (Girshick et al., 2014), and segmentation (Long et al., 2015; He et al., 2017). Another enormously successful application of fine-tuning is in NLP: pre-training transformers and fine-tuning them on downstream tasks yields state-of-the-art results on many NLP benchmarks (Devlin et al., 2019; Yang et al., 2019). In contrast to the two-stage optimization process of fine-tuning, joint training optimizes a linear combination of the objectives of the source and target tasks (Kokkinos, 2017; Kendall et al., 2017; Liu et al., 2019b).

Despite the pervasiveness of fine-tuning and joint training, recent works have uncovered that they are not always panaceas for transfer learning. Geirhos et al. (2019) found that pre-trained models learn the texture of ImageNet, which is biased and not transferable to target tasks. ImageNet pre-training does not necessarily improve accuracy on COCO (He et al., 2018), fine-grained classification (Kornblith et al., 2019), or medical imaging tasks (Raghu et al., 2019). Wu et al. (2020) observed that large model capacity and the discrepancy between the source and target domains eclipse the effect of joint training. Nonetheless, we do not yet have a systematic understanding of what makes the successes of fine-tuning and joint training inconsistent.

The goal of this paper is two-fold: (1) to understand more about when and why fine-tuning and joint training can be suboptimal or even harmful for transfer learning; (2) to design algorithms that overcome the drawbacks of fine-tuning and joint training and consistently outperform them. To address the first question, we hypothesize that fine-tuning and joint training have no incentive to prefer learning transferable features over source-specific features, so their capability of learning transferable features is rather accidental, depending on the properties of the datasets. To analyze this hypothesis empirically, we design a semi-synthetic dataset whose source data simultaneously contain artificially amplified transferable features and source-specific features. Both the transferable and source-specific features can solve the source task, but only the transferable features are useful for the target. We analyze which features fine-tuning and joint training learn; see Figure 1 for an illustration of the semi-synthetic experiments. We observe the following failure patterns of fine-tuning and joint training on the semi-synthetic dataset.

• Pre-training may learn non-transferable features that do not help the target when both transferable and source-specific features can solve the source task, because it is oblivious to the target data. When the dataset contains source-specific features that are more convenient for neural nets to use, pre-training learns them; as a result, fine-tuning starting from the source-specific features does not lead to improvement.

• Joint training learns source-specific features and overfits to the target. A priori, it may appear that joint training should prefer transferable features because the target data is present in the training loss. However, joint training easily overfits to the target, especially when the target dataset is small. When the source-specific features are the most convenient for the source, joint training simultaneously learns the source-specific features and memorizes the target dataset.

Toward overcoming the drawbacks of fine-tuning and joint training, we first note that, unlike fine-tuning, any proposed algorithm should use the source and the target simultaneously to encourage extracting shared structures. Second, and more importantly, we recall that good representations should enable generalization: we should not only be able to fit a target head on top of the representations (as joint training does), but the learned head should also generalize well to a held-out target dataset. With this intuition, we propose Meta Representation Learning (MeRLin) to encourage learning transferable and generalizable features: we meta-learn a feature extractor such that the head fit to a target training set performs well on a target validation set. In contrast to standard model-agnostic meta-learning (MAML) (Finn et al., 2017), which aims to learn prediction models that are adaptable to multiple target tasks from multiple source tasks, our method meta-learns transferable representations with only one source and one target domain.

Empirically, we first verify that MeRLin learns transferable features on the semi-synthetic dataset. We then show that MeRLin outperforms state-of-the-art transfer learning baselines on real-world vision and NLP tasks, such as ImageNet to fine-grained classification and language modeling to GLUE. Theoretically, we analyze the mechanism of the improvement brought by MeRLin: in a simple two-layer quadratic neural network setting, we prove that MeRLin recovers the target ground truth with only limited target examples, whereas both fine-tuning and joint training fail to learn transferable features that perform well on the target.

In summary, our contributions are as follows. (1) Using a semi-synthetic dataset, we analyze and diagnose when and why fine-tuning and joint training fail to learn transferable representations. (2) We design a meta representation learning algorithm (MeRLin) that outperforms state-of-the-art transfer learning baselines. (3) We rigorously analyze the behavior of fine-tuning, joint training, and MeRLin in a special two-layer neural net setting.

2. SETUP AND PRELIMINARIES

In this paper, we study supervised transfer learning. Consider an input-label pair $(x, y) \in \mathbb{R}^d \times \mathbb{R}$. We are provided with a source distribution $\mathcal{D}_s$ and a target distribution $\mathcal{D}_t$ over $\mathbb{R}^d \times \mathbb{R}$. The source dataset $D_s = \{(x_i^s, y_i^s)\}_{i=1}^{n_s}$ and the target dataset $D_t = \{(x_i^t, y_i^t)\}_{i=1}^{n_t}$ consist of $n_s$ i.i.d. samples from $\mathcal{D}_s$ and $n_t$ i.i.d. samples from $\mathcal{D}_t$, respectively. Typically $n_s \gg n_t$. We view a predictor as the composition of a feature extractor $h_\phi : \mathbb{R}^d \to \mathbb{R}^m$ parameterized by $\phi \in \Phi$, which is often a deep neural net, and a head classifier $g_\theta : \mathbb{R}^m \to \mathbb{R}$ parameterized by $\theta \in \Theta$, which is often linear. That is, the final prediction is $f_{\theta,\phi}(x) = g_\theta(h_\phi(x))$. Let $\ell(\cdot, \cdot)$ denote the loss function, e.g., the cross-entropy loss for classification tasks. Our goal is to learn an accurate model on the target domain $\mathcal{D}_t$. Since the label sets of the source and target tasks can be different, we usually learn two heads for the source task and the target task separately, denoted by $\theta_s$ and $\theta_t$, with a shared feature extractor.
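The shared-extractor, two-head architecture from Section 2 can be sketched as follows. This is our own minimal illustration, not the paper's code: the dimensions, the single linear layer for $h_\phi$, and the tanh nonlinearity are all arbitrary choices made for the sketch.

```python
import numpy as np

# Shared feature extractor h_phi with two task-specific linear heads:
# theta_s for the source task and theta_t for the target task.
rng = np.random.default_rng(0)
d, m = 10, 4                      # input dim d, feature dim m (arbitrary)

phi = rng.normal(size=(m, d))     # feature extractor parameters (one linear layer)
theta_s = rng.normal(size=m)      # source head parameters
theta_t = rng.normal(size=m)      # target head parameters

def h(x):
    return np.tanh(phi @ x)       # shared representation h_phi(x)

def f_source(x):
    return theta_s @ h(x)         # source prediction g_{theta_s}(h_phi(x))

def f_target(x):
    return theta_t @ h(x)         # target prediction g_{theta_t}(h_phi(x))

x = rng.normal(size=d)
pred_s, pred_t = f_source(x), f_target(x)
```

Both predictions share the same representation $h_\phi(x)$; only the final linear map differs per task, which is what allows the source and target label sets to differ.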

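To make the semi-synthetic construction of Section 1 concrete, here is a toy numerical analogue (our illustration; the paper's actual semi-synthetic datasets are built from real vision/NLP data). The source inputs carry two predictive signals: a transferable feature, which remains predictive on the target, and a source-specific shortcut, which is the more "convenient" cue on the source but carries no target information.

```python
import numpy as np

rng = np.random.default_rng(0)
n = 1000

# Source data: both column 0 (transferable) and column 1 (shortcut)
# predict the source label; column 2 is pure noise.
z = rng.integers(0, 2, size=n)                 # latent source label in {0, 1}
transferable = z + 0.1 * rng.normal(size=n)    # predictive on source AND target
shortcut = z + 0.01 * rng.normal(size=n)       # lower-noise, "more convenient" cue
X_source = np.stack([transferable, shortcut, rng.normal(size=n)], axis=1)
y_source = z

# Target data: the shortcut channel is scrambled, so only the
# transferable feature carries label information.
zt = rng.integers(0, 2, size=n)
X_target = np.stack([zt + 0.1 * rng.normal(size=n),
                     rng.normal(size=n),       # shortcut channel is now noise
                     rng.normal(size=n)], axis=1)
y_target = zt
```

A model that minimizes source loss alone can latch onto the shortcut column and solve the source task perfectly, yet transfer nothing useful, which is exactly the failure mode the semi-synthetic experiments are designed to expose.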

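MeRLin's bilevel objective, i.e., fit a head on the target training split, then update the feature extractor so that this fitted head performs well on the target validation split, can be sketched in a toy linear setting. This is our simplification, not the paper's implementation: we use a linear feature map, a closed-form ridge-regression inner solve, and a finite-difference outer gradient with backtracking, whereas the paper uses neural-net features and also trains on the source task.

```python
import numpy as np

rng = np.random.default_rng(0)
d, m, lam = 5, 3, 1e-2
# Toy target task: labels depend only on the first two input coordinates.
X_tr = rng.normal(size=(20, d)); y_tr = X_tr[:, 0] - X_tr[:, 1]
X_val = rng.normal(size=(20, d)); y_val = X_val[:, 0] - X_val[:, 1]

def meta_loss(phi_flat):
    phi = phi_flat.reshape(m, d)              # linear feature map h_phi(x) = phi x
    H_tr, H_val = X_tr @ phi.T, X_val @ phi.T
    # Inner step: ridge-regression head fit on the target TRAINING split.
    theta = np.linalg.solve(H_tr.T @ H_tr + lam * np.eye(m), H_tr.T @ y_tr)
    # Outer objective: that head's squared loss on the target VALIDATION split.
    return np.mean((H_val @ theta - y_val) ** 2)

phi = 0.1 * rng.normal(size=m * d)
before = meta_loss(phi)
for _ in range(50):
    # Finite-difference gradient of the meta-objective (for clarity only;
    # a real implementation would backpropagate through the inner solve).
    g = np.array([(meta_loss(phi + 1e-5 * e) - meta_loss(phi - 1e-5 * e)) / 2e-5
                  for e in np.eye(phi.size)])
    cur, step = meta_loss(phi), 0.5
    while step > 1e-8 and meta_loss(phi - step * g) >= cur:
        step *= 0.5                           # backtracking: only take improving steps
    if step > 1e-8:
        phi = phi - step * g
after = meta_loss(phi)
```

Because the head is refit on the training split at every outer step, the feature extractor is rewarded only for representations whose fitted head generalizes to held-out target data, rather than for representations that merely allow the training split to be fit.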