VARIATIONAL MULTI-TASK LEARNING

Abstract

Multi-task learning aims to improve the overall performance of a set of tasks by leveraging their relatedness. When training data is limited, using priors is pivotal, but currently this is done in ad hoc ways. In this paper, we develop variational multi-task learning (VMTL), a general probabilistic inference framework for simultaneously learning multiple related tasks. We cast multi-task learning as a variational Bayesian inference problem, which enables task relatedness to be explored in a principled way by specifying priors. We introduce Gumbel-softmax priors to condition the prior of each task on related tasks. Each prior is represented as a mixture of the variational posteriors of other related tasks, and the mixing weights are learned in a data-driven manner for each individual task. The posteriors over representations and classifiers are inferred jointly for all tasks, and each individual task improves its performance by exploiting the shared inductive bias. Experimental results demonstrate that VMTL can effectively tackle challenging multi-task learning with limited training data, achieving state-of-the-art performance on four benchmark datasets and consistently surpassing previous methods.

1. INTRODUCTION

Multi-task learning (Caruana, 1997) is a fundamental paradigm in machine learning, which aims to solve multiple related tasks simultaneously, improving the performance of each individual task through knowledge sharing. The crux of multi-task learning is how to explore task relatedness (Argyriou et al., 2007; Zhang & Yeung, 2012), which is non-trivial since the underlying relationship among tasks can be complicated and highly nonlinear. This has been extensively investigated in previous work by learning shared features, designing regularizers imposed on parameters (Pong et al., 2010; Kang et al., 2011; Jawanpuria et al., 2015; Jalali et al., 2010), or exploring priors over parameters (Heskes, 2000; Bakker & Heskes, 2003; Xue et al., 2007; Zhang & Yeung, 2012; Long et al., 2017). Recently, deep neural networks have been developed that learn shared representations in the feature layers while keeping the classifier layers independent (Yang & Hospedales, 2016; Hashimoto et al., 2016; Ruder et al., 2017). It would be beneficial to learn representations and classifiers jointly by fully leveraging the knowledge shared among related tasks, which however remains an open problem. In our work, we consider a particularly challenging setting, where each task contains only limited training data. Even more challenging, we have only a handful of related tasks from which to gain shared knowledge. This is in stark contrast to few-shot learning (Gordon et al., 2018; Finn et al., 2017; Vinyals et al., 2016), which also suffers from limited data per task but usually has access to a large number of related tasks. Therefore, in our scenario, it is difficult to learn a proper model for each task independently without overfitting (Long et al., 2017; Zhang & Yang, 2017), and it is crucial to leverage the inductive bias (Baxter, 2000) provided by the other related tasks that are learned simultaneously.
To do so, we employ the Bayesian framework, as it is able to deliver uncertainty estimates on predictions and automatic model regularization (MacKay, 1992; Graves, 2011), which makes it well suited for multi-task learning with limited training data. The major motivation of our work is to leverage the Bayesian learning framework to address the challenge of limited data in multi-task learning. In this paper, we introduce variational multi-task learning (VMTL), a novel variational Bayesian inference approach that can explore task relatedness in a principled way. In order to fully utilize the shared knowledge from related tasks, we explore task relationships in both the feature representation and the classifier by placing prior distributions over them in a Bayesian framework. Thus, multi-task learning is cast as a variational inference problem over feature representations and classifiers jointly. The introduced variational inference allows us to specify the priors by conditioning on the variational posteriors of other related tasks.
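To make the mechanism concrete, the following is a minimal, hypothetical sketch of the two ingredients described above: mixing weights over related tasks drawn with the Gumbel-softmax trick, and a task's prior formed from the other tasks' Gaussian variational posteriors. For illustration only, the mixture is moment-matched to a single Gaussian; all names and the toy numbers are assumptions, not the paper's implementation.

```python
import numpy as np

def gumbel_softmax(logits, tau=1.0, rng=None):
    """Relaxed sample of mixing weights: add Gumbel(0,1) noise to the
    learnable logits, then apply a temperature-scaled softmax."""
    rng = rng or np.random.default_rng(0)
    g = -np.log(-np.log(rng.uniform(size=logits.shape)))  # Gumbel(0,1) noise
    y = (logits + g) / tau
    y = y - y.max()  # numerical stability before exponentiation
    e = np.exp(y)
    return e / e.sum()

def mixture_prior(post_means, post_vars, weights):
    """Form one task's prior from the other tasks' Gaussian posteriors,
    moment-matching the weighted mixture to a single Gaussian (a
    simplification for this sketch)."""
    mean = np.sum(weights[:, None] * post_means, axis=0)
    second = np.sum(weights[:, None] * (post_vars + post_means**2), axis=0)
    return mean, second - mean**2  # mixture mean and variance

# Toy setup: 3 related tasks, 4-dimensional classifier posteriors.
means = np.array([[0.1, 0.2, 0.0, -0.1],
                  [0.0, 0.3, 0.1,  0.0],
                  [0.2, 0.1, -0.1, 0.1]])
vars_ = np.full_like(means, 0.5)
logits = np.zeros(3)  # per-task mixing logits, learned in practice
w = gumbel_softmax(logits, tau=0.5)
prior_mean, prior_var = mixture_prior(means, vars_, w)
assert np.isclose(w.sum(), 1.0) and np.all(prior_var > 0)
```

Because the Gumbel-softmax sample is a differentiable function of the logits, the data-driven mixing weights can be trained by backpropagation in the full model; at low temperature the weights approach a one-hot selection of a single related task.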

