FP AINET: FUSION PROTOTYPE WITH ADAPTIVE INDUCTION NETWORK FOR FEW-SHOT LEARNING

Anonymous authors
Paper under double-blind review

Abstract

A prototypical network treats all samples equally and does not account for noisy samples, which leads to biased class representations. In this paper, we propose a novel fusion prototype with an adaptive induction network (FP AINet) for few-shot learning that can learn representative prototypes from a few support samples. Specifically, to address the problem of noisy samples, an adaptive induction network is developed that learns different class representations for different queries and assigns adaptive scores to support samples according to their relative significance. By considering query-related samples, FP AINet generates more accurate prototypes than comparison methods. As the number of samples increases, the prototypical network becomes more expressive, whereas the adaptive induction network ignores relative local features; a Gaussian fusion algorithm is therefore designed to learn more representative prototypes. Extensive experiments are conducted on three datasets: miniImageNet, tieredImageNet, and CIFAR-FS. Experimental results compared with state-of-the-art few-shot learning methods demonstrate the superiority of FP AINet.

1. INTRODUCTION

Few-shot learning aims to learn classifiers for novel classes with limited data. The prototypical network (PN) (Snell et al., 2017) averages the support features to form the prototype. While most previous research has achieved promising results, these methods generally assume that the samples used for training were carefully selected to represent their class. The expected prototype should have the smallest distance to all other samples of its class (Liu et al., 2020), and each sample contributes significantly to the final performance when training from a few labeled samples. Unfortunately, existing datasets frequently contain mislabeled samples because of weak automated supervised annotation, ambiguity, or human error (Liang et al., 2022). In addition, since some images contain multiple objects and unrelated background information, accuracy can be affected by a single noisy example. As illustrated in Figure 1(a), the PN is easily affected by noisy samples.

Meta-learning approaches have become the dominant paradigm for few-shot learning (Chen et al., 2020; Tian et al., 2020; Yao et al., 2021). They can be roughly divided into two categories: optimization-based methods (Antoniou et al., 2019; Kao et al., 2022) and metric-based methods (Vinyals et al., 2016; Sung et al., 2018). Optimization-based methods learn the model's parameters to adapt to each task using gradient descent; however, they need to be fine-tuned for the target tasks. Metric-based methods are more efficient and applicable. They learn a good metric to calculate the similarity between the query and the support samples using a pre-defined distance function, such as cosine similarity (Vinyals et al., 2016), Euclidean distance (Snell et al., 2017; Koch et al., 2015), earth mover's distance (Zhang et al., 2020), or a distance parameterized by a neural network (Sung et al., 2018; Zhang et al., 2018), and have achieved remarkable success due to their fewer parameters.

To obtain more representative prototypes, many methods correct the prototype by using similar samples (Yang et al., 2021; Liu et al., 2020) or additional knowledge (Zhang et al., 2021), but these approaches easily introduce sample noise or class differences. We therefore propose a novel fusion prototype with an adaptive induction network (FP AINet) to address this issue. The induction network (Geng et al., 2019) designs a non-linear mapping from sample vectors to a class vector to diminish prototype bias. However, since the model has not seen query samples before extracting support features, some inappropriate features may be extracted, resulting in a significant deviation in prototype estimation. An adaptive induction network (AINet) is proposed to extract more reliable prototypes for each class. AINet does not take into account the local relative importance of different regions in a sample, while the prototype generated by the PN becomes more discriminative and expressive as the number of support samples increases, as shown in Figure 1(b). To overcome the limitation that a single prototype is not comprehensive, we assume the estimated prototype follows a multivariate Gaussian distribution (Zhang et al., 2021). Specifically, the features in the target task are transformed using the Yeo-Johnson transformation, and the two kinds of prototypes generated by AINet and the PN are then combined. Finally, the performance of FP AINet is evaluated on miniImageNet, tieredImageNet, and CIFAR-FS, and ablation experiments validate its effectiveness. Experimental results show that FP AINet generates more representative prototypes and improves the accuracy of few-shot learning.
The main contributions are summarized as follows: (1) A novel AINet is proposed to automatically assign scores to support samples based on their relevance. (2) A modified Gaussian-based fusion algorithm is employed to aggregate the prototypes from the PN and AINet by exploiting unlabeled samples. (3) Extensive experiments on three datasets demonstrate the effectiveness of FP AINet.

Metric-based methods. To boost the performance of the PN, the task-dependent adaptive metric (TADAM) (Oreshkin et al., 2018) proposes metric scaling and task conditioning. It is difficult to represent the distribution of a class with limited samples, so many methods have been proposed to correct bias in prototype estimation (Hou & Sato, 2021; Yang et al., 2021). BD-CSPN (Liu et al., 2020) modifies prototypes by diminishing intra-class and cross-class bias; pseudo-labels are used to reduce intra-class bias, but they easily introduce noise. Rather than relying on a pre-defined metric to calculate similarity (Vinyals et al., 2016), the relation network (Sung et al., 2018) and the deep comparison network (Zhang et al., 2018) train deep neural networks to compare each query-support image pair. While previous methods adopted a class representation based on the first moment (Snell et al., 2017), CovaMNet (Li et al., 2019) adopts the second moment for feature description. Unlike the above methods, multi-level metric learning (Chen et al., 2022) measures similarity at three different feature levels. According to the above analysis, most existing methods ignore noisy samples, resulting in biased class representations. To address this issue, this paper proposes a more accurate prototype estimation method to improve few-shot image classification performance.

2. RELATED WORK

Transductive few-shot learning. In general, the inductive few-shot setting is employed when data acquisition is expensive, and the transductive setting is applied when data labeling is expensive (Bendou et al., 2022). Some studies have tackled the problem by utilizing additional knowledge from the query dataset or extra unlabeled examples in a transductive setting (Wang et al., 2020; Nichol et al., 2018). However, they share knowledge between query datasets via batch normalization rather than explicitly modeling the transductive setting as in Flennerhag et al. (2020). Task-adaptive feature sub-space learning (TAFSSL) (Lichtenstein et al., 2020) looks for discriminative feature sub-spaces for few-shot classification tasks. In contrast to unidirectional label propagation, mutual centralized learning (MCL) treats query and support features as bipartite data and avoids self-reinforcement (Liu et al., 2022). Inspired by transductive few-shot learning, unlabeled samples are employed to estimate the prototype and enrich the feature representation.

Meta-training stage. In each episode, K samples of each class are used as the support set S, and K' samples are selected as the query set Q from the remaining samples of each class. The feature extractor f_θf(·) can then be fine-tuned on the query set. During each episode, the mean-based prototype p̄_i is estimated by averaging the labeled support features. Furthermore, AINet is proposed to learn the class prototype p'_i, which is derived from the features f_θf(x) of the support and query samples. To obtain more mutual information, the fusion prototype p_i is calculated using the Gaussian-based fusion method. Finally, the cosine similarity between f_θf(x) and p_i determines the probability that each sample x ∈ Q belongs to class i.
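As a sketch of this final classification step, each query can be assigned to the class whose prototype has the highest cosine similarity. The NumPy snippet below is illustrative; the feature arrays and their shapes are assumptions, not the paper's exact implementation:

```python
import numpy as np

def classify_queries(query_feats, prototypes):
    """Assign each query feature to the class whose prototype has the
    highest cosine similarity. query_feats: (M, d), prototypes: (N, d)."""
    q = query_feats / (np.linalg.norm(query_feats, axis=1, keepdims=True) + 1e-12)
    p = prototypes / (np.linalg.norm(prototypes, axis=1, keepdims=True) + 1e-12)
    return np.argmax(q @ p.T, axis=1)  # index of the most similar prototype
```

Normalizing both sides turns the dot product into the cosine similarity, so the argmax picks the nearest prototype in angular distance.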

3. METHOD

Meta-testing stage. This stage is the same as meta-training, except that the classification task is performed on D_novel.

3.2.1. ADAPTIVE INDUCTION NETWORK

The induction module (IM) (Geng et al., 2019) learns the class-level relationship by treating features and classes as a local-global relationship. However, because the support samples are diverse and incomplete, every support sample contributes differently to the class representation when faced with different query samples. To learn a more representative class vector and reduce sample noise, we propose AINet, which pays more attention to the instances that are effective for the current query sample. The details of AINet are shown in Algorithm 1. Using a multi-head self-attention mechanism, the support vector z^s_ij and query vector z^q are concatenated to calculate a relationship score, so each support vector is weighted with respect to the current query vector. Then, dynamic routing is applied to obtain the class vector. This process adjusts the connection strengths dynamically and ensures that the coupling coefficients d_i between class i and all of its support samples sum to 1. The difference from the original IM is that when adjusting the logits of the coupling coefficients in the last step of every iteration, we consider not only the consistency between the class candidate vector and the sample prediction vectors but also the relationship between the query and support vectors.
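As a rough sketch of this query-conditioned routing (Algorithm 1 gives the precise procedure), the loop below uses a scalar attention score per support sample; the linear map W, the bias b, and all shapes are illustrative assumptions rather than the paper's exact parameterization:

```python
import numpy as np

def softmax(x):
    e = np.exp(x - np.max(x))
    return e / e.sum()

def squash(v):
    # non-linear squashing: (||v||^2 / (1 + ||v||^2)) * v / ||v||
    n2 = float(np.dot(v, v))
    return (n2 / (1.0 + n2)) * v / (np.sqrt(n2) + 1e-12)

def ainet_class_vector(z_support, z_query, W, b, r=3):
    """Query-conditioned dynamic routing over the K support vectors of
    one class. z_support: (K, d), z_query: (d,), W: (d, d), b: (d,)."""
    K, d = z_support.shape
    # scalar attention score per support sample w.r.t. the current query
    att = softmax(z_support @ z_query / np.sqrt(d))
    # prediction vectors from a learned transformation
    z_hat = np.stack([squash(W @ z + b) for z in z_support])
    logits = np.zeros(K)  # routing logits b_ij
    for _ in range(r):
        coup = softmax(logits)      # coupling coefficients, sum to 1
        p = squash(coup @ z_hat)    # class candidate vector
        # agreement update, modulated by the query-support attention score
        logits = logits + att * np.tanh(z_hat @ p)
    return p
```

The softmax over the routing logits enforces that the coupling coefficients sum to 1, and the squashing function keeps the class vector's norm below 1, as in capsule-style routing.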
Algorithm 1 Adaptive Induction Network
Require: sample vectors z^s_ij in support dataset S and a vector z^q in query dataset Q; initialize the logits of the coupling coefficients b_ij = 0
Ensure: class vector p'_i
  for all samples j = 1, ..., K in class i:
    z_ij = Concat(z^s_ij, z^q), i.e., the concatenation of z^s_ij and z^q
    s_ij = softmax(z_ij z_ij^T / √d) z_ij, where d is the dimension of z_ij
    ẑ^s_ij = squash(W^s z^s_ij + b^s), where W^s is a transformation weight matrix and b^s a bias
  for r iterations do
    d_i = softmax(b_i)
    p'_i = Σ_j d_ij · ẑ^s_ij, where ẑ^s_ij is the prediction vector and p'_i the class candidate vector
    p'_i = squash(p'_i) = (‖p'_i‖² / (1 + ‖p'_i‖²)) · (p'_i / ‖p'_i‖), where squash is a non-linear squashing function
    for all samples j = 1, ..., K in class i:
      b_ij = b_ij + s_ij · tanh(ẑ^s_ij · p'_i)
  end for
  return p'_i

3.2.2. PROTOTYPE FUSION

When the number of training samples is limited, p'_i is more accurate because the model focuses on the most representative features, while p̄_i becomes more representative as the number of samples increases because the model considers only global features and ignores local features. This means that p̄_i and p'_i can learn mutual affiliations with each other (Zhang et al., 2021). To address this, a prototype fusion algorithm is proposed to reduce prototype bias. We assume that each estimated prototype follows a Gaussian distribution and that the distributions are independent of each other, because samples in the pre-trained space are continuous and clustered. Algorithm 2 describes the Gaussian-based prototype fusion. To make the features follow a multivariate normal distribution (Yang et al., 2021), the input features are preprocessed using the Yeo-Johnson transformation (Weisberg, 2001). The Yeo-Johnson transformation reduces the heteroskedasticity of random variables and increases their normality, resulting in a probability density function close to the normal distribution.
At the same time, the Yeo-Johnson transformation can be applied to samples with zero and negative features, making it suitable for statistical analysis of random variables under the normal assumption, as in Equation 1:

f_θf(x) =
  ((f_θf(x) + 1)^λ − 1) / λ,                  λ ≠ 0, f_θf(x) ≥ 0
  log(f_θf(x) + 1),                           λ = 0, f_θf(x) ≥ 0
  −((−f_θf(x) + 1)^(2−λ) − 1) / (2 − λ),      λ ≠ 2, f_θf(x) < 0
  −log(−f_θf(x) + 1),                         λ = 2, f_θf(x) < 0        (1)

where f_θf(x) is the feature to be transformed and λ is employed to correct the distribution. Then, the mean-based prototype p̄_i is estimated by averaging the features of the labeled support samples, as in Equation 2:

p̄_i = (1 / |S_i|) Σ_{x∈S_i} f_θf(x)        (2)

where S_i is the support dataset of class i and f_θf(x) are its features. We assume that p̄_i follows a Gaussian distribution with mean μ̄_i and diagonal covariance diag(σ̄²_i), and that p'_i is a sample from N(μ'_i, diag(σ'²_i)). To improve the class representation, we learn a Gaussian distribution with mean μ̄_i + μ'_i and diagonal covariance diag(σ̄²_i + σ'²_i), whose mean is used to calculate the fusion prototype p_i, as in Equation 3:

θ ~ N(μ̄_i, diag(σ̄²_i)),  θ' ~ N(μ'_i, diag(σ'²_i)),  θ + θ' ~ N(μ̄_i + μ'_i, diag(σ̄²_i + σ'²_i))        (3)

A transductive few-shot learning method (Liu et al., 2020) is used to calculate μ̄_i and μ'_i by leveraging the unlabeled samples. With the class prototype p̄_i or p'_i, Equations 4 and 5 give the class probability for each x ∈ S ∪ Q, where S is the support dataset with a few labeled samples and Q is the query dataset with unlabeled samples:

P(y = i | x) = exp(d(f_θf(x), p̄_i)) / Σ_c exp(d(f_θf(x), p̄_c))        (4)

P'(y = i | x) = exp(d(f_θf(x), p'_i)) / Σ_c exp(d(f_θf(x), p'_c))        (5)

where d(·,·) is the cosine similarity. Then, μ̄_i and μ'_i can be calculated by regarding P(y = i | x) and P'(y = i | x) as the weights, as shown in Equations 6 and 7.
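A minimal NumPy sketch of this pipeline, under the stated assumptions (cosine similarity as d, soft weights computed over all of S ∪ Q): the Yeo-Johnson transform of Equation 1 followed by the probability-weighted class means and their sum (Equations 4-8). The input shapes are illustrative:

```python
import numpy as np

def yeo_johnson(f, lam=0.5):
    """Element-wise Yeo-Johnson transform of Equation 1."""
    f = np.asarray(f, dtype=float)
    out = np.empty_like(f)
    pos = f >= 0
    out[pos] = ((f[pos] + 1) ** lam - 1) / lam if lam != 0 else np.log(f[pos] + 1)
    out[~pos] = (-((-f[~pos] + 1) ** (2 - lam) - 1) / (2 - lam)
                 if lam != 2 else -np.log(-f[~pos] + 1))
    return out

def _softmax(x, axis):
    e = np.exp(x - x.max(axis=axis, keepdims=True))
    return e / e.sum(axis=axis, keepdims=True)

def fuse_prototypes(feats, protos_pn, protos_ai):
    """feats: (M, d) features of all x in S ∪ Q; protos_pn / protos_ai:
    (N, d) prototypes from the PN and AINet. Returns the fused prototypes
    p_i = mu_bar_i + mu_prime_i (Equations 4-8)."""
    fn = feats / (np.linalg.norm(feats, axis=1, keepdims=True) + 1e-12)

    def soft_mean(protos):
        pn = protos / (np.linalg.norm(protos, axis=1, keepdims=True) + 1e-12)
        P = _softmax(fn @ pn.T, axis=1)       # (M, N): P(y = i | x), Eq. 4/5
        W = P / P.sum(axis=0, keepdims=True)  # normalize weights per class
        return W.T @ feats                    # (N, d): weighted means, Eq. 6/7

    return soft_mean(protos_pn) + soft_mean(protos_ai)  # Eq. 8
```

Note that λ = 1 reduces Equation 1 to the identity transform, which is a convenient sanity check for the implementation.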

μ̄_i = (1 / Σ_{x∈S∪Q} P(y = i | x)) Σ_{x∈S∪Q} P(y = i | x) f_θf(x)        (6)

μ'_i = (1 / Σ_{x∈S∪Q} P'(y = i | x)) Σ_{x∈S∪Q} P'(y = i | x) f_θf(x)        (7)

Finally, the fusion prototype p_i is obtained from μ̄_i and μ'_i, as in Equation 8:

p_i = μ̄_i + μ'_i        (8)

4. EXPERIMENTAL SETUP

The classical 5-way 1-shot/5-shot episodic few-shot task settings are adopted. The query dataset contains 6 images per class during the meta-training stage and 15 test samples per class during the meta-testing stage, and 10,000 tasks are randomly constructed. Each task is then tested, and the average top-1 classification accuracy with a 95% confidence interval is reported as the final result.

4.2. IMPLEMENTATION DETAILS

The experiments employ a ResNet-12 feature extractor that produces 640-dimensional features. Each residual block contains three 3 × 3 convolutional layers and a shortcut connection. WRN-28-10, with 28 layers and a width factor of 10, is used for tieredImageNet, and its extracted features are 512-dimensional. Average pooling is applied at the last block of each architecture to obtain feature vectors (Mangla et al., 2020). In the pre-training stage, the base class dataset is trained for 100 epochs with a batch size of 128. SGD with a momentum of 0.9 and a weight decay of 0.0005 is adopted to train the ResNet-12 feature extractor, while the Adam optimizer is used for WRN-28-10. In the meta-training stage, data augmentation techniques are used, including random cropping, color jittering, and horizontal flipping. The model is meta-trained for 60 epochs, each containing 1,000 episodes, with an initial learning rate of 0.1. At epochs 20, 40, and 50, the learning rate drops to 0.006, 0.0012, and 0.00024, respectively. λ is set to 0.5 in the Yeo-Johnson transform, and 3 iterations are used for AINet.
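Read literally, the learning-rate schedule above corresponds to a simple step function; the assumption that each new rate takes effect exactly at the listed epoch is ours:

```python
def meta_train_lr(epoch):
    """Step learning-rate schedule from the implementation details:
    0.1 initially, then 0.006 / 0.0012 / 0.00024 at epochs 20 / 40 / 50."""
    if epoch < 20:
        return 0.1
    if epoch < 40:
        return 0.006
    if epoch < 50:
        return 0.0012
    return 0.00024
```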

4.3.1. COMPARISON WITH STATE-OF-THE-ART METHODS

Tables 1 and 2 show the 5-way 1-shot/5-shot classification results of FP AINet and state-of-the-art few-shot learning methods on miniImageNet and tieredImageNet, respectively. Table 1 shows that FP AINet outperforms the comparison methods on miniImageNet. In the 5-way 1-shot and 5-shot settings, the accuracy of FP AINet reaches 72.13% and 84.29%, respectively, about 0.34% and 0.89% higher than the suboptimal methods Curvature Generation (Gao et al., 2021) and UniSiam (Lu et al., 2022). On tieredImageNet, the 1-shot accuracy of FP AINet is 0.49% higher than that of the second-best model, BD-CSPN (Liu et al., 2020), and the 5-shot accuracy is 0.29% higher than that of EPNet (Rodríguez et al., 2020). This improvement is attributed to attending to the more important samples. Moreover, the Gaussian-based fusion algorithm alleviates prototype error and facilitates learning the optimal prototype by exploring the unlabeled samples. Table 3 shows the comparison of FP AINet with the main few-shot learning methods on CIFAR-FS. In the 5-way 1-shot setting, the accuracy of FP AINet reaches 81.92%, 0.32% higher than the suboptimal method SSR (Shen et al., 2021), which shows that FP AINet handles extremely few-shot classification tasks better. In the 5-way 5-shot setting, the accuracy of FP AINet is 89.38%, 0.38% higher than the suboptimal method EASY (Bendou et al., 2022). FP AINet achieves the highest accuracy with the same backbone, and accurate prototypes prove more effective than fully extracted features. Furthermore, accuracy in the 5-shot setting is significantly higher than in the 1-shot setting; the main reason is that fewer annotated samples result in inaccurate prototype estimation, whereas 5 shots yield a more representative prototype.
It is verified that FP AINet can better handle few-shot learning tasks with a limited amount of data: the prototype features of novel classes are expressed more abundantly and accurately by fusing the prototypes.

Table 2: 5-way 1-shot/5-shot accuracy (%) on tieredImageNet with 95% confidence intervals.

Table 4 summarizes the ablation results of FP AINet and shows that each component contributes to few-shot image classification, giving improvements over the state of the art on miniImageNet. In the table, (i) denotes classification using only the PN, (ii) the induction module, (iii) only AINet, and (iv) the Gaussian-based fusion algorithm. In the 5-way 1-shot setting, if neither module is used, accuracy drops by more than 10%. The prototype fusion algorithm of FP AINet achieves better performance than AINet alone.

Adaptive induction network. Row (iii) of Table 4 shows that in the 5-way 1-shot setting, AINet outperforms the PN, mainly because it calculates the prototype using query samples and sample selection. At the same time, the induction prototype method captures class-level information and automatically adjusts the coupling coefficients according to the input, which suits few-shot learning and achieves good results in the presence of noise. In the 5-way 5-shot setting, as the number of samples increases, the mean-based prototype obtains a better class representation. The results demonstrate that attending to the effective support samples is an important factor in few-shot classification.

Prototype fusion. As shown by model (iv) in Table 4, fusion improves the accuracy of AINet by about 9% and 6% in the 5-way 1-shot and 5-shot settings, respectively. The results indicate that fusing prototypes improves model performance and alleviates bias in prototype estimation.
The primary argument is that prototype fusion utilizes more samples, which more effectively addresses sample noise and incompleteness in few-shot learning. The results show the necessity and effectiveness of learning an optimal class prototype. The prototypes generated by FP AINet are visualized using t-distributed stochastic neighbor embedding (t-SNE); a 5-way 1-shot task on miniImageNet is shown in Figure 3.

B COMPARISON OF COMPUTATION COST

Table 5 shows a detailed analysis of the different modules of our method and the classification results on miniImageNet. Compared with the PN baseline, FLOPs increase by 16M, mainly because the attention mechanism introduces additional parameters. The fusion operation requires almost no additional computation. In total, our method adds 410.9K parameters.

C COMPARISON OF DIFFERENT BACKBONES

In order to explore the influence of the feature embedding vectors, the depth of the backbone network is varied while the same settings are used for the three models. It can be seen from Table 6 that the



(a) Prototype with noisy samples. (b) Test accuracy on miniImageNet.

Figure 1: Different prototype models. (a) shows a sample misclassified by the PN; different colors represent different classes, and the orange circle denotes the sample to be classified. (b) illustrates the test accuracy of different prototypes in the 5-way k-shot setting.

3.1. PROBLEM DEFINITION

A few-shot classification setting includes two datasets: the base class dataset D_base with abundant labeled images and the novel class dataset D_novel with few labeled data. Suppose D_base = {(x^base_t, y^base_t)} is sampled from the base classes C_base, where y^base_t ∈ C_base is the label of x^base_t. There is no intersection between the base and novel classes, that is, C_base ∩ C_novel = ∅ and C_base ∪ C_novel = C. In each iteration, an episode is formed by selecting N classes at random, each contributing K labeled samples to the support dataset S = {(x_s, y_s)}^{N×K}_{s=1}. The query set Q = {(x_q, y_q)}^{N×K+N×K'}_{q=N×K+1} contains examples of the same N classes as S, where K' is the number of query samples per class. The model must predict a class label for each query sample given the N support classes, each containing K support samples.

3.2. OVERALL ARCHITECTURE

FP AINet consists of three stages: the pre-training stage, the meta-training stage, and the meta-testing stage. An overview of FP AINet is provided in Figure 2.
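An N-way K-shot episode as defined in the problem definition above can be sampled with a few lines of Python; the dict-of-lists dataset layout is an assumption for illustration only:

```python
import random

def sample_episode(dataset, n_way=5, k_shot=1, k_query=15, seed=None):
    """dataset: dict mapping class name -> list of samples.
    Returns (support, query) lists of (sample, episode_label) pairs."""
    rng = random.Random(seed)
    classes = rng.sample(sorted(dataset), n_way)  # pick N classes at random
    support, query = [], []
    for label, c in enumerate(classes):
        picks = rng.sample(dataset[c], k_shot + k_query)  # disjoint S and Q
        support += [(x, label) for x in picks[:k_shot]]
        query += [(x, label) for x in picks[k_shot:]]
    return support, query
```

Sampling K + K' items per class at once guarantees that the support and query sets are disjoint, as the problem definition requires.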

Figure 2: An overview of FP AINet.

In Figure 3, circles represent query samples and different colors denote different classes; stars represent PN prototypes, pentagons AINet prototypes, and squares FP AINet prototypes. The prototypes generated by FP AINet are much closer to the class centers, which shows that it effectively learns the prototype representation and improves the capacity of the support dataset.

Figure 3: t-SNE visualization of different prototypes.

Figure 4: The accuracy of different prototype methods on different datasets.

Figure 5: Accuracy for different values of λ on miniImageNet.

The query features before and after the Yeo-Johnson transformation are shown in Figure 6, where different colors represent categories. The distribution before the transformation is noticeably skewed, while the distribution after the Yeo-Johnson transformation satisfies the Gaussian assumption very well, providing a powerful means of reducing skewness.

Figure 6: Feature transformation.

Algorithm 2 Gaussian-based Prototype Fusion
Ensure: fusion prototype p_i
  for each episode do
    Create the episodic task using S and Q, and fine-tune the feature extractor f_θf(·)
    Estimate the mean-based prototype p̄_i with Equation 2
    Calculate the class vector p'_i with Algorithm 1
    Use p̄_i and p'_i to calculate the probability of x ∈ S ∪ Q with Equations 4 and 5, respectively
    Calculate μ̄_i and μ'_i from P(y = i | x) and P'(y = i | x) with Equations 6 and 7, respectively
    Estimate the fusion prototype p_i from μ̄_i and μ'_i with Equation 8
  end for

The miniImageNet dataset is divided into 64, 16, and 20 classes for training, validation, and testing. tieredImageNet (Ren et al., 2018) consists of 608 classes grouped into 34 higher-level classes: the training dataset contains 20 higher-level classes (351 fine-grained classes), the validation set 6 higher-level classes (97 fine-grained classes), and the test set 8 higher-level classes (160 fine-grained classes). The image size of miniImageNet and tieredImageNet is 84 × 84 × 3. CIFAR-FS (Bertinetto et al., 2019) contains 100 classes with 600 images each, split into 64 training, 16 validation, and 20 test classes. The image size is unified to 32 × 32 × 3.

Table 1: 5-way 1-shot/5-shot accuracy (%) on miniImageNet with 95% confidence intervals. The best two results are highlighted and underlined.

Table 3: 5-way 1-shot/5-shot accuracy (%) on CIFAR-FS with 95% confidence intervals.

Table 4: Ablation studies of 5-way 1-shot/5-shot on miniImageNet.

