HIERARCHICAL GAUSSIAN MIXTURE BASED TASK GENERATIVE MODEL FOR ROBUST META-LEARNING

Abstract

Meta-learning enables quick adaptation of machine learning models to new tasks with limited data. While tasks can come from varying distributions in reality, most existing meta-learning methods assume that both training and testing tasks are drawn from the same uni-component distribution, overlooking two critical needs of a practical solution: (1) the various sources of tasks may compose a multi-component mixture distribution, and (2) novel tasks may come from a distribution that is unseen during meta-training. In this paper, we demonstrate that these two challenges can be solved jointly by modeling the density of task instances. We develop a meta-training framework built upon a novel Hierarchical Gaussian Mixture based Task Generative Model (HTGM). HTGM extends the widely used empirical process of sampling tasks to a theoretical model, which learns task embeddings, fits the mixture distribution of tasks, and enables density-based scoring of novel tasks. The framework is agnostic to the encoder and scales well with large backbone networks. The model parameters are learned end-to-end by maximum likelihood estimation via an Expectation-Maximization algorithm. Extensive experiments on benchmark datasets indicate the effectiveness of our method for both sample classification and novel task detection.
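The core idea of density-based novel-task detection can be sketched with a plain (non-hierarchical) Gaussian mixture over task embeddings, fit by EM. The sketch below is illustrative only: it assumes diagonal covariances and pre-computed embeddings, and the function names are hypothetical, not part of HTGM.

```python
import numpy as np

def fit_gmm_em(X, k, n_iters=50, seed=0):
    """Fit a diagonal-covariance Gaussian mixture to task embeddings X
    (shape: n_tasks x dim) via Expectation-Maximization."""
    rng = np.random.default_rng(seed)
    n, d = X.shape
    means = X[rng.choice(n, k, replace=False)]  # initialize means from data
    vars_ = np.ones((k, d))                     # unit variances
    pi = np.full(k, 1.0 / k)                    # uniform mixing weights
    for _ in range(n_iters):
        # E-step: responsibilities r[i, j] = p(component j | x_i)
        log_p = (-0.5 * (((X[:, None] - means) ** 2) / vars_
                         + np.log(2 * np.pi * vars_)).sum(-1)
                 + np.log(pi))
        log_p -= log_p.max(1, keepdims=True)
        r = np.exp(log_p)
        r /= r.sum(1, keepdims=True)
        # M-step: re-estimate weights, means, and variances
        nk = r.sum(0)
        pi = nk / n
        means = (r.T @ X) / nk[:, None]
        vars_ = (r.T @ (X ** 2)) / nk[:, None] - means ** 2 + 1e-6
    return pi, means, vars_

def log_density(x, pi, means, vars_):
    """Mixture log-density of a task embedding; low values flag novel tasks."""
    log_p = (-0.5 * (((x - means) ** 2) / vars_
                     + np.log(2 * np.pi * vars_)).sum(-1)
             + np.log(pi))
    m = log_p.max()
    return m + np.log(np.exp(log_p - m).sum())
```

In this simplified view, a testing task whose embedding receives a low mixture log-density is scored as novel, while in-distribution tasks fall near one of the fitted components.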

1. INTRODUCTION

Training models in small data regimes is of fundamental importance. It demands a model's ability to quickly adapt to new environments and tasks. To compensate for the lack of training data for each task, meta-learning (a.k.a. learning to learn) has become an essential paradigm for model training by generalizing meta-knowledge across tasks (Snell et al., 2017; Finn et al., 2017). While most existing meta-learning approaches were built upon the assumption that all training/testing tasks are sampled from the same distribution, a more realistic scenario should accommodate training tasks that lie in a mixture of distributions, and testing tasks that may belong to or deviate from the learned distributions. For example, in recent medical research, a global model is typically trained on the historical medical records of a certain set of patients in the database (Shukla & Marlin, 2019; Wu et al., 2021). However, due to the uniqueness of individuals (e.g., gender, age, genetics), patients' data exhibit substantial discrepancies, and the pre-trained model may demonstrate significant demographic or geographical biases when tested on a new patient (Purushotham et al., 2017). This issue can be mitigated by personalized medicine approaches (Chan & Ginsburg, 2011; Ni et al., 2022) where each patient is regarded as a task, and the pre-trained model is fine-tuned (i.e., personalized) for adaptation on a support set of a few records collected from each patient over a short period (e.g., a few weeks). In this case, the training tasks (i.e., patients) could be sampled from a mixture of distributions (e.g., different age groups), and a testing task may or may not belong to any of the observed groups. As such, a meta-training strategy that is able to fit a mixture of task distributions and identify novel tasks is desirable for making meta-learning a practical solution.
One way to tackle the mixture distributions of tasks is to tailor the transferable knowledge to each task by learning a task-specific representation (Oreshkin et al., 2018; Vuorio et al., 2018; Lee & Choi, 2018), but as discussed in (Yao et al., 2019a), the over-customized knowledge prevents its generalization among closely related tasks (e.g., tasks from the same distribution). More recent methods try to balance the generalization and customization of the meta-knowledge by promoting local generalization either among a cluster of related tasks (Yao et al., 2019a), or within a neighborhood of a meta-knowledge graph of tasks (Yao et al., 2019b). Neither of them explicitly learns the

