HIERARCHICAL GAUSSIAN MIXTURE BASED TASK GENERATIVE MODEL FOR ROBUST META-LEARNING

Abstract

Meta-learning enables quick adaptation of machine learning models to new tasks with limited data. While tasks could come from varying distributions in reality, most existing meta-learning methods consider both training and testing tasks as coming from the same uni-component distribution, overlooking two critical needs of a practical solution: (1) the various sources of tasks may compose a multi-component mixture distribution, and (2) novel tasks may come from a distribution that is unseen during meta-training. In this paper, we demonstrate that these two challenges can be solved jointly by modeling the density of task instances. We develop a meta-training framework underlain by a novel Hierarchical Gaussian Mixture based Task Generative Model (HTGM). HTGM extends the widely used empirical process of sampling tasks to a theoretical model, which learns task embeddings, fits the mixture distribution of tasks, and enables density-based scoring of novel tasks. The framework is agnostic to the encoder and scales well with large backbone networks. The model parameters are learned end-to-end by maximum likelihood estimation via an Expectation-Maximization algorithm. Extensive experiments on benchmark datasets demonstrate the effectiveness of our method for both sample classification and novel task detection.

1. INTRODUCTION

Training models in small data regimes is of fundamental importance. It demands a model's ability to quickly adapt to new environments and tasks. To compensate for the lack of training data for each task, meta-learning (a.k.a. learning to learn) has become an essential paradigm for model training by generalizing meta-knowledge across tasks (Snell et al., 2017; Finn et al., 2017). While most existing meta-learning approaches were built upon the assumption that all training/testing tasks are sampled from the same distribution, a more realistic scenario should accommodate training tasks that lie in a mixture of distributions, and testing tasks that may belong to or deviate from the learned distributions. For example, in recent medical research, a global model is typically trained on the historical medical records of a certain set of patients in a database (Shukla & Marlin, 2019; Wu et al., 2021). However, due to the uniqueness of individuals (e.g., gender, age, genetics), patients' data exhibit substantial discrepancies, and the pre-trained model may demonstrate significant demographic or geographical biases when tested on a new patient (Purushotham et al., 2017). This issue can be mitigated by personalized medicine approaches (Chan & Ginsburg, 2011; Ni et al., 2022), where each patient is regarded as a task, and the pre-trained model is fine-tuned (i.e., personalized) on a support set of a few records collected in a short period (e.g., a few weeks) from every patient for adaptation. In this case, the training tasks (i.e., patients) could be sampled from a mixture of distributions (e.g., different age groups), and a testing task may or may not belong to any of the observed groups. As such, a meta-training strategy that is able to fit a mixture of task distributions and identify novel tasks is desirable for making meta-learning a practical solution.
One way to tackle the mixture distributions of tasks is to tailor the transferable knowledge to each task by learning a task-specific representation (Oreshkin et al., 2018; Vuorio et al., 2018; Lee & Choi, 2018), but as discussed in (Yao et al., 2019a), over-customized knowledge prevents generalization among closely related tasks (e.g., tasks from the same distribution). More recent methods try to balance the generalization and customization of the meta-knowledge by promoting local generalization either among a cluster of related tasks (Yao et al., 2019a), or within a neighborhood of a meta-knowledge graph of tasks (Yao et al., 2019b). Neither of them explicitly learns the underlying distribution from which the tasks are generated, rendering them infeasible for detecting novel tasks that are out-of-distribution. However, detecting novel tasks is crucial in high-stake domains such as medicine and finance, as it provides users (e.g., physicians) confidence on whether to trust the results of a testing task, and facilitates downstream decision-making. In (Lee et al., 2019a), a task-specific tuning variable was introduced to modulate the initial parameters learned by MAML (Finn et al., 2017), so that the impacts of the meta-knowledge on different tasks are adjusted differently, e.g., novel tasks receive less impact than known tasks do. However, this method focuses on improving model performance on different tasks (either known or novel), but neglects the critical mission of detecting which tasks are novel. In practice, reporting an unreliable accuracy on a novel task without differentiating it from other tasks may be meaningless and risky. Since the aforementioned methods cannot simultaneously handle the mixture distribution of tasks and novel tasks, a practical solution is in demand.
In this work, we consider tasks as instances, and demonstrate that the dual problems of modeling the mixture of task distributions and detecting novel tasks are two sides of the same coin, i.e., density estimation on task instances. To this end, we propose a new Hierarchical Gaussian Mixture based Task Generative Model (HTGM) to explicitly model the generative process of task instances. Our contributions are summarized as follows.

• For the first time, the widely used empirical process of generating a task is theoretically extended to and specified by a hierarchy of Gaussian mixture (GM) distributions. HTGM generates a task embedding from a task-level GM, and uses it to define the task-conditioned mixture probabilities for a class-level GM, from which samples are drawn to instantiate the generated task. To allow realistic classes per task, a new Gibbs distribution is proposed to underlie the class-level GM.

• HTGM is an encoder-agnostic framework, and is thus flexible across different domains. It builds on metric-based meta-learning methods, and only introduces a small overhead to an encoder for parameterizing its distributions; it is therefore efficient, and enables large-scale backbone networks. The model parameters are learned end-to-end by maximum likelihood estimation via a principled Expectation-Maximization (EM) algorithm. The bounds of our likelihood function are theoretically analyzed.

• In the experiments, we evaluated HTGM on benchmark image datasets to validate its ability to take advantage of large backbone networks, its effectiveness in modeling the mixture distribution of tasks, and its usefulness in identifying novel tasks. The results demonstrate that HTGM outperforms the state-of-the-art (SOTA) baselines with significant improvements in most cases.
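To make the two-level generative process concrete, the following is a minimal NumPy sketch of the idea described above: a task embedding is drawn from a task-level Gaussian mixture, a softmax over negative distances (standing in for the paper's Gibbs distribution) converts that embedding into task-conditioned class-mixture probabilities, and samples are then drawn from the selected class-level Gaussians. All dimensions, component counts, variances, and the class-prototype pool are hypothetical placeholders, not values from the paper.

```python
import numpy as np

rng = np.random.default_rng(0)

# Hypothetical dimensions and component counts (illustrative only).
D = 8         # embedding dimension
K_TASK = 3    # number of task-level mixture components
M_CLASSES = 20  # pool of class-level Gaussians (class prototypes)
N_WAY = 5     # classes per task
N_SHOT = 4    # samples per class

# Task-level GM: uniform mixing weights, random means, unit spherical variance.
task_weights = np.full(K_TASK, 1.0 / K_TASK)
task_means = rng.normal(size=(K_TASK, D))

# Class-level Gaussian means (class prototypes).
class_means = rng.normal(size=(M_CLASSES, D))

def sample_task():
    """Sample one task via the hierarchical generative process."""
    # 1) Task-level GM: pick a component, draw the task embedding t.
    k = rng.choice(K_TASK, p=task_weights)
    t = rng.normal(loc=task_means[k], scale=1.0)

    # 2) Task-conditioned class probabilities: softmax over negative
    #    squared distances between t and class prototypes (Gibbs-like form).
    logits = -np.sum((class_means - t) ** 2, axis=1)
    probs = np.exp(logits - logits.max())
    probs /= probs.sum()

    # 3) Pick N_WAY distinct classes, draw N_SHOT samples from each
    #    class-level Gaussian to instantiate the task.
    classes = rng.choice(M_CLASSES, size=N_WAY, replace=False, p=probs)
    X = np.stack([rng.normal(loc=class_means[c], scale=0.5, size=(N_SHOT, D))
                  for c in classes])
    y = np.repeat(np.arange(N_WAY), N_SHOT)
    return t, X.reshape(N_WAY * N_SHOT, D), y

def task_log_density(t):
    """Log-density of a task embedding under the task-level GM
    (unit-variance spherical components); a low value flags a novel task."""
    log_comp = -0.5 * np.sum((t - task_means) ** 2, axis=1) \
               - 0.5 * D * np.log(2 * np.pi)
    return np.log(np.sum(task_weights * np.exp(log_comp)))

t, X, y = sample_task()
print(t.shape, X.shape, y.shape)  # (8,) (20, 8) (20,)
```

Once the mixture parameters are fitted (via EM in the paper), `task_log_density` illustrates how density-based scoring of novel tasks can work: a test task whose embedding falls in a low-density region of the task-level mixture is flagged as out-of-distribution.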

2. RELATED WORK

To the best of our knowledge, this is the first work to explicitly model the generative process of task instances from a mixture of distributions for meta-learning with novel task detection. Meta-learning aims to handle the few-shot learning problem, and has derived memory-based (Mishra et al., 2018), optimization-based (Finn et al., 2017; Li et al., 2017), and metric-based methods (Vinyals et al., 2016; Snell et al., 2017), which often consider an artificial scenario where training/testing tasks are sampled from the same distribution. To enable more varying tasks, task-adaptive methods facilitate the customization of meta-knowledge by learning task-specific parameters (Rusu et al., 2018; Lee & Choi, 2018), temperature scaling parameters (Oreshkin et al., 2018), and task-specific modulation on model initialization (Vuorio et al., 2018; Yao et al., 2019a;b; Lee et al., 2019a). Among them, there are methods tackling the mixture distribution of tasks by clustering tasks (Yao et al., 2019a) or learning task graphs (Yao et al., 2019b), and a method that relocates the initial parameters for different tasks so that they use the meta-knowledge differently (Lee et al., 2019a). As discussed before, none of these methods jointly handle the mixture of task distributions and the detection of novel tasks. Our model is built upon metric-based methods, and learns task embeddings for modeling task distributions. Achille et al. (2019) also proposed to learn embeddings for tasks and introduced a meta-learning method, but not for few-shot learning. Its embeddings are from a pre-specified set of tasks (rather than episode-wise sampling), and the meta-learning framework is for model selection. The model in (Yao et al., 2019a) has an augmented encoder for task embedding, but it does not explicitly model task generation, and is not designed for novel task detection (empirical comparison in Section 4.1).
Conventional novelty detection aims to identify and reject samples from unseen classes (Cheng & Vasconcelos, 2021). It relates to open-set recognition (Vaze et al., 2022), which aims to simultaneously identify unknown samples and classify samples from known classes. Out-of-distribution

