Print

发布时间: 2019-09-16
摘要点击次数:
全文下载次数:
DOI: 10.11834/jig.180629
2019 | Volume 24 | Number 9




    图像分析和识别    




  <<上一篇 




  下一篇>> 





代表特征网络的小样本学习方法
expand article info 汪荣贵, 郑岩, 杨娟, 薛丽霞
合肥工业大学计算机与信息学院, 合肥 230601

摘要

目的 小样本学习任务旨在仅提供少量有标签样本的情况下完成对测试样本的正确分类。基于度量学习的小样本学习方法通过将样本映射到嵌入空间,计算距离得到相似性度量以预测类别,但未能从类内多个支持向量中归纳出具有代表性的特征以表征类概念,限制了分类准确率的进一步提高。针对该问题,本文提出代表特征网络,分类效果提升显著。方法 代表特征网络通过类代表特征的度量学习策略,利用类中支持向量集学习得到的代表特征有效地表达类概念,实现对测试样本的正确分类。具体地说,代表特征网络包含两个模块,首先通过嵌入模块提取抽象层次高的嵌入向量,然后堆叠嵌入向量经过代表特征模块得到各个类代表特征。随后通过计算测试样本嵌入向量与各类代表特征的距离以预测类别,最后使用提出的混合损失函数计算损失以拉大嵌入空间中相互类别间距减少相似类别错分情况。结果 经过广泛实验,在Omniglot、miniImageNet和Cifar100数据集上都验证了本文模型不仅可以获得目前已知最好的分类准确率,而且能够保持较高的训练效率。结论 代表特征网络可以从类中多个支持向量有效地归纳出代表特征用于对测试样本的分类,对比直接使用支持向量进行分类具有更好的鲁棒性,进一步提高了小样本条件下的分类准确率。

关键词

小样本学习; 度量学习; 代表特征网络; 混合损失函数; 微调

Representative feature networks for few-shot learning
expand article info Wang Ronggui, Zheng Yan, Yang Juan, Xue Lixia
School of Computer Science and Information Engineering, Hefei University of Technology, Hefei 230601, China

Abstract

Objective Few-shot learning aims to build a classifier that recognizes new unseen classes given only a few samples. The solutions are mainly in the following categories:data augmentation, meta-learning, and metric learning. Data augmentation can be used to reduce certain over-fitting given a limited data regime in a new class. The corresponding solution is to augment data in the feature domain as hallucinating features. These methods exert a certain effect on few-shot classification. However, due to the extremely small data space, the transformation mode is considerably limited and cannot solve over-fitting problems. The meta-learning method is suitable for few-shot learning because it is based on the high-level strategy of learning similar tasks. Some methods learn good initial values, some learn task-level update strategies, and others construct external memory storages to remember past information for comparison during testing. The few-shot classification results of these methods are superior, but the network structure is increasingly complicated due to the use of RNNs(recurrent neural networks). The efficiency is also low. The metric learning method is simple and efficient. It first maps a sample to the embedding space and then computes the distance to obtain the similarity metric to predict the category. Some approaches improve the representation of features in the embedding space, some use learnable distance metrics to compute distance for loss, and others combine meta-learning methods to improve accuracy. However, this type of method fails to summarize representative features from multiple support vectors in a class to effectively represent the class concept. This drawback limits the further improvement of the accuracy of small sample classification. To address this problem, this study proposes a representative feature network. Method The representative feature network is a metric learning strategy based on class representative features. It uses the representative features learned from a support vector set in a class to express the class concept effectively. It also uses mixture loss to reduce the misclassification of similar classes and thus achieve excellent classification results. Specifically, the representative feature network includes two modules. The embedded vector of a high abstraction level is extracted by the embedded module, and then the representative feature per class is obtained by the representative feature module by inputting stacked support vector sets. The class representative feature fully considers the influence of the embedded vector of the support samples on the basis of the target that may or may not be obvious. The use of network learning to assign different weights to each embedded support vector can effectively avoid misclassification caused by the bias effects of representative features for unobvious target samples. Then, the distances from the embedded query vectors to each class representative feature are calculated to predict the class. In addition, the mixture loss function is proposed for the misclassification of similar classes in the embedded space. The cross-entropy loss combined with the relative error loss function is used to increase the inter-class distances and reduce the similar class error rate. Result After extensive experiments, the Omniglot, miniImageNet, and Cifar100 datasets verify that the model achieves state-of-the-art results. For the simple Omniglot dataset, the five-way, five-shot classification accuracy is 99.7%, which is 1% higher than that of the original matching network. For the complex miniImageNet dataset, the five-way, five-shot classification accuracy is 75.83%, which is approximately 18% higher than that of the original matching network. Representative features provide approximately 8% improvement, indicating that it can effectively express the prototype by distinguishing the contribution of different support vectors, the target of which may or may not be obvious. Mixture loss provides approximately 1% improvement, indicating that it can reduce some misclassification of similar classes in the testing set. However, the improvement is unremarkable because similar samples are uncommon in the dataset. The last 9% improvement is due to the fine-tuning on the test set, indicating that the advantage of the skip connection method benefits loss propagation relative to the original connection between the network module methods. For the Cifar100 dataset, the five-way, five-shot classification accuracy is 87.99%, which is 20% higher than that of the original matching network. Moreover, the high training efficiency is maintained while the performance is significantly improved. Conclusion To address the problem of extremely simple original embedding networks for extracting high-level features of samples, the improved embedding networks in a representative feature network use a skip connection structure so as to deepen the network and extract advanced features. To address the problem of the noise support vector that disturbs the classification accuracy of a testing sample, the representative feature network can effectively summarize the representative features from multiple support vectors in a class for classifying query samples. Compared with the performance when support vectors are used directly, the classification performance when representative features are used is more robust, and the classification accuracy under few-shot samples is further improved. In addition, the mixture loss function proposed for the classification problem of similar classes is used to enlarge the distance between categories in the embedded space and reduce the misclassification of similar classes. Detailed experiments are carried out to verify that these improved methods achieve great performance in few-shot learning tasks for the Omniglot, miniImageNet, and Cifar100 datasets. At the same time, the representative feature network presents improvement. For embedding networks, advanced structures, such as dense connections or se modules, must be included in future work to further improve the results.

Key words

few-shot learning; metric learning; representative feature network; mixture loss function; fine-tuning

0 引言

近来,深度学习(deep learning)在具有大数据集的图像分类[1-4]、目标检测[5-8]和机器翻译[9-12]等任务上取得了重大进展,错误率越来越低,其中部分领域分类识别的定位能力已超过人类。这些成就都是基于深度模型、使用大量标签样本进行训练、迭代更新模型参数取得的。这种类型的优化在数据集较小时效果较差,因为仅是简单地基于小数据集训练深度网络会使网络严重过拟合。在这种情况下,小样本学习(few-shot learning)[13-15]应运而生。小样本学习任务是指对于训练中没有的新类,在仅给出少量标注样本的情况下,也能够正确识别。这对于深度学习可能是困难的,但是对于人类来说是相当容易的,甚至幼儿也能在给出的一个从未见过的老虎的几张图片中概括出老虎的概念,从而能很好地泛化,识别其他图片中的老虎。完成这个任务的动机不仅于此,而且会有许多应用,如医疗影像中罕见病例的识别分类用于辅助诊断、海量监控视频中嫌疑人搜索识别用于辅助侦察等。最明显的一点即是它仅需几张标签样本就能有较好的分类结果,不再需要百万千万级别的标签样本,从而大大减轻样本标注的工作量。对于该类小样本学习任务,首先想到的是迁移学习(transfer learning)[16-18],其可以通过预先在大数据集上训练,然后在目标小数据集上微调网络得到对目标类别的一个较好的识别率。然而研究证明,当目标类别与训练样本类别差异较大时,预先训练网络的表现大大降低[19]。这种情况需要的是抽象类层次概念,而不是样本层次。此外,由于目标小数据集每类只有几张(few-shot)甚至一张(one-shot)标注样本,直接微调仍然无法很好地学习到目标小数据集的类别概念。

在小样本学习任务上的解决方案主要有数据增强(data augmentation)、元学习(meta learning)和度量学习(metric learning)。考虑到新类中的数据量极少,数据增强可以用来减少一定的过拟合,相应的方法有在特征层面进行数据增强,如幻想特征[20-21],在小样本分类上有一定的效果提升,但是由于数据空间极小,导致变换模式很有限,并不能解决过拟合问题。而元学习方法[22-26]由于其建立在任务基础上学习相似任务间的高层策略,比较适合小样本学习,受到广泛青睐,通过学习好的初始化条件[22-23],通过任务层面的更新策略[22]或RNN (recurrent neural network)构建外部记忆存储器(external memory)[24-25]来记住大量样本以供测试时比对。这些方法都取得了较好的小样本分类效果,但是由于过多使用RNN导致网络结构较为复杂,且效率较低。而度量学习方法更加简单高效,首先通过嵌入网络(embedding network)学习样本的嵌入向量(embedded vector),然后在嵌入空间(embedding space)中直接求解最近邻达到预测分类的目的[27-32],通过使用分段(episode)[28]训练方式、改进的嵌入空间[29-30]、可学习的距离度量[31-32]进一步提高小样本分类精度。

度量学习方法中较有代表性的成果是Vinyals等人[28]提出的匹配网络(matching networks),使用注意力机制基于支持集(support set)上学习的嵌入网络来预测测试集(query set)的类别。该模型的最大贡献是在训练期间使用分段的抽样小批量数据,模拟测试任务。这种方式使得训练环境和测试环境更加接近,从而提高了测试时的泛化表现,此种训练方法在后期研究中被广泛使用。另一贡献是提出适用于小样本学习任务的miniImageNet数据集,也被广泛用作实验数据集。Snell等人[29]进一步挖掘嵌入空间中类嵌入向量间的联系,认为每个类别都存在一个原型(prototype)表达,相应类的嵌入向量聚集在原型周围,而原型是支持集嵌入向量的均值。据此提出原型网络(prototypical networks),基于支持集嵌入向量求出类原型后,分类问题就变成在嵌入空间中测试样本嵌入向量与类原型的最近邻,取得了良好效果。

本文基于匹配网络模型训练方法[28],受原型[29]思想启发提出代表特征网络,并且针对其使用简单平均求类原型,不能很好地评估类中各个支持集样本特征向量对于类原型的不同贡献,提出代表特征概念。代表特征网络包含两个串联模块:嵌入模块和代表特征模块。首先使用嵌入模块提取各个支持集样本嵌入向量,然后堆叠多个嵌入向量输入代表特征模块,得到最终的类代表特征向量。另外考虑到嵌入空间中类间差异较小导致错误分类的情况,提出混合损失函数。使用交叉熵损失联合相对误差损失函数,使得在正确分类的同时,增大各个类嵌入空间间距,降低相似类别错分概率。实验测试时在Omniglot[15]、miniImageNet[28]和Cifar100[33]数据集上使用支持集微调达到目前已知的最好效果。

1 本文方法

本文处理小样本分类任务的基本流程是:1)将分段中的支持集和测试样本输入到代表特征网络中,得到相应的类代表特征向量和测试样本嵌入向量;2)计算测试样本嵌入向量和各代表特征向量的余弦距离;3)使用提出的混合损失函数根据距离得到损失,随后反向传播求梯度更新网络。如图 1所示,嵌入空间中的长条代表各个样本的嵌入向量,用红、绿、蓝表示类别。样本首先经过嵌入模块得到各个样本的嵌入向量,然后经过代表特征模块得到类代表特征向量。图 1下方的虚线代表反向传播过程。训练采用分段训练方式。

图 1 代表特征网络总体架构图
Fig. 1 Representative feature network architecture

1.1 代表特征网络

已有的度量学习方法[28-29]嵌入网络较为简单,多是卷积层的简单堆叠,不能很好地学习到miniImageNet[28]等复杂数据集中的非线性,为此,本文使用带有跳过连接结构的深层嵌入模块提高嵌入向量的抽象层次。针对原型方法[29]简单使用嵌入向量均值作为类代表特征不具有鲁棒代表性导致错误分类的情况,本文使用代表特征模块学习得到鲁棒性更强的类代表特征。串联嵌入模块和代表特征模型即可得到本文提出的代表特征网络。

1.1.1 嵌入模块

在已有的度量学习方法[28-29]中嵌入简单的4层CNN (convolutional neural network)网络结构,考虑到这种CNN结构对复杂数据集中的非线性的学习能力有限,而增加网络深度被证明可以更好地提取目标特征[2],跳过连接(skip connection)结构[1]可以在加深网络的同时避免退化(degradation)现象[1],本文借鉴He等人[1]提出的残差网络(residual network),使用改进的深层嵌入模块,其中使用的残差块包含3个3×3卷积层(convolutional layer),然后用一个跳过连接的结构,后面再接上ReLU激活和2×2最大池化层(maximum pooling layer)。残差块和整个嵌入模块的结构如图 2所示。图 2(a)是本文嵌入网络中使用的残差块的结构,图 2(b)是本文使用残差块的整个嵌入模块的结构。输入图片,使用1×1的卷积层把深度扩展到512,然后接全局平均池化层(global average pooling layer)输出512维嵌入向量。

图 2 嵌入模块结构
Fig. 2 Embedding module architecture ((a)residual block; (b)embedding module)

残差网络的优势主要在于其跳过连接结构,该结构是在标准结构中增加了一个恒等映射, 即

$ \mathit{\boldsymbol{y}} = R(\mathit{\boldsymbol{x}}) + \mathit{\boldsymbol{x}} $ (1)

式中,$\mathit{\boldsymbol{x}}$代表网络模块的输入向量,$\mathit{\boldsymbol{y}}$代表输出向量,$R(\mathit{\boldsymbol{x}})=\mathit{\boldsymbol{y}}-\mathit{\boldsymbol{x}}$即为要学习的残差映射,而标准网络中的输出由输入经过网络模块直接映射得到:$\mathit{\boldsymbol{y}}=F(\mathit{\boldsymbol{x}})$。学习残差$R(\mathit{\boldsymbol{x}})$比直接学习$F(\mathit{\boldsymbol{x}})$要容易得多。理论上分析可知,如果网络已经达到最优,继续加深网络,残差映射$R(\mathit{\boldsymbol{x}})$会被置零处理,相当于模块进行了恒等映射$\mathit{\boldsymbol{y}}=\mathit{\boldsymbol{x}}$,使其一直保持最优状态。而让网络学习零,则很简单。

具体地,本文嵌入模块在使用跳过连接时,当残差块的输入输出通道相同时,直接执行加操作,即

$ \mathit{\boldsymbol{y}} = R\left( {\mathit{\boldsymbol{x}},{\mathit{\boldsymbol{w}}_i}} \right) + \mathit{\boldsymbol{x}} $ (2)

式中,残差映射$R(\mathit{\boldsymbol{x}}, \mathit{\boldsymbol{w}}_{i})$始终与输出$\mathit{\boldsymbol{y}}$通道保持一致,$\mathit{\boldsymbol{w}}_{i}$表示模块中卷积层参数。

当通道不同时,需要对输入向量进行线性变换$\mathit{\boldsymbol{W}}$使通道与输出通道一致,使用$1×1$卷积操作实现,即

$ \mathit{\boldsymbol{y}} = R\left( {\mathit{\boldsymbol{x}},{\mathit{\boldsymbol{w}}_i}} \right) + \mathit{\boldsymbol{Wx}} $ (3)

对比ResNet[1]结构,本文提出的嵌入模块更适用于小样本学习任务。本文设计网络的出发点是使用跳过连接来加深网络,更好地提取特征,并且尽量减少参数数量以防过拟合。首先,该模块使用最大池化来减半特征图的大小而不是ResNet[1]中使用更大步长的卷积层,参数量总计减少了460 k。其次,该模块最后使用$1×1$卷积层扩展深度代替全连接层,本文认为全卷积网络更有利于提取特征,实验证明,该方法提高了0.2%的精度。最后,该模块带参层数共13层,且选择了最合适的卷积核数量。这些方法达到了在保证深度和效果的同时,尽可能减少参数的目的。相比最小的ResNet-18[1] (参数量为11.13×106),本文(参数量为1.24×106)减少了88.86%的参数,并且取得了更好的效果,如表 1所示。由于输入尺寸的不同,本文未使用ResNet-18最后的两个下采样层。

表 1 不同嵌入网络结构对比
Table 1 Comparison of various embedding networks architecture

下载CSV
网络名 结构 带参层数量 参数数量
简单CNN $\left[ {\begin{array}{*{20}{c}} {{\rm{C}}:3 \times 3, 64}\\ {{\rm{MP}}:2 \times 2} \end{array}} \right]\quad \left[ {\begin{array}{*{20}{c}} {{\rm{C}}:3 \times 3, 64}\\ {{\rm{MP}}:2 \times 2} \end{array}} \right]\quad \left[ {\begin{array}{*{20}{c}} {{\rm{C}}:3 \times 3, 64}\\ {{\rm{MP}}; 2 \times 2} \end{array}} \right]\quad \left[ {\begin{array}{*{20}{c}} {{\rm{C}}:3 \times 3, 64}\\ {{\rm{MP}}:2 \times 2} \end{array}} \right]$ 4 0.11×106
ResNet-18 $\left[ {\begin{array}{*{20}{c}} {{\rm{C}}:3 \times 3, 64}\\ {{\rm{C}}:3 \times 3, 64} \end{array}} \right] \times 2\left[ {\begin{array}{*{20}{l}} {{\rm{C}}:3 \times 3, 128}\\ {{\rm{C}}:3 \times 3, 128} \end{array}} \right] \times 2\left[ {\begin{array}{*{20}{l}} {{\rm{C}}:3 \times 3, 256}\\ {{\rm{C}}:3 \times 3, 256} \end{array}} \right] \times 2\left[ {\begin{array}{*{20}{l}} {{\rm{C}}:3 \times 3, 512}\\ {{\rm{C}}:3 \times 3, 512} \end{array}} \right] \times 2\; \; \; \; \left[ {{\rm{GAP}}} \right]$ 16 11.13×106
嵌入模块
(本文)
$\left[ {\begin{array}{*{20}{c}} {{\rm{C}}:3 \times 3, 64}\\ {{\rm{C}}:3 \times 3, 64}\\ {{\rm{C}}:3 \times 3, 64}\\ {{\rm{MP}}:2 \times 2} \end{array}} \right]\quad \left[ {\begin{array}{*{20}{c}} {{\rm{C}}:3 \times 3, 96}\\ {{\rm{C}}:3 \times 3, 96}\\ {{\rm{C}}:3 \times 3, 96}\\ {{\rm{MP}}:2 \times 2} \end{array}} \right]\quad \left[ {\begin{array}{*{20}{c}} {{\rm{C}}:3 \times 3, 128}\\ {{\rm{C}}:3 \times 3, 128}\\ {{\rm{C}}:3 \times 3, 128}\\ {{\rm{MP}}:2 \times 2} \end{array}} \right]\quad \left[ {\begin{array}{*{20}{c}} {{\rm{C}}:3 \times 3, 128}\\ {{\rm{C}}:3 \times 3, 128}\\ {{\rm{C}}:3 \times 3, 128}\\ {{\rm{MP}}:2 \times 2} \end{array}} \right]\quad \left[ {\begin{array}{*{20}{c}} {{\rm{C}}:1 \times 1, 512}\\ {{\rm{GAP}}} \end{array}} \right]$ 13 1.24×106
注:C表示卷积层,MP表示最大池化层,GAP表示全局平均池化层。

1.1.2 代表特征模块

在一个小样本学习任务中,对一个有$N$个标注样本的支持集$\mathit{\boldsymbol{S}}=\{(\mathit{\boldsymbol{x}}_{1}, y_{1}), …, (\mathit{\boldsymbol{x}}_{N}, y_{N})\}$,这里$\mathit{\boldsymbol{x}}_{i}∈{\bf {R}}^{D}$是样本的$D$维输入向量,$y_{i}∈\{1, …, h\}$是相应的标签,$h$是标签总数。$\mathit{\boldsymbol{S}}_{k}$表示支持集中类别为$k$的样本集,$N_{S}$表示支持集一类中的样本数量。首先每个样本输入向量经过嵌入网络学习一个嵌入函数映射$f_{θ}:{\bf {R}}^{D}→{\bf {R}}^{M}$,其中$θ$是嵌入网络的可学习参数,$M$维空间是映射变换后的嵌入向量空间。在原始的原型网络[29]中其计算的类原型(class prototype)为支持集类中嵌入向量的均值,即

$ {\mathit{\boldsymbol{c}}_{k{\rm{\_src}}}} = \frac{1}{{\left| {{\mathit{\boldsymbol{S}}_k}} \right|}}\sum\limits_{\left( {{x_i},{y_i}} \right) \in {\mathit{\boldsymbol{S}}_k}} {{f_\theta }} \left( {{\mathit{\boldsymbol{x}}_i}} \right) $ (4)

原型网络[29]中提出用均值作为类原型的方法是基于伯格曼散度思想,文献[34]证明在特定空间中的一组点满足任意概率分布的情况下,这些点的均值点一定是空间中距离这些点的平均距离的最小值点。本文认为在小样本情况,尤其是测试环境下,样本对应嵌入空间的分布并不足以满足任意概率分布,所以认为原型位置不是简单地对类嵌入向量求取均值,类中嵌入向量有接近和不接近原型之分,即存在样本目标较为不明显的嵌入向量,由于目标不明显,例如目标前景较小、背景较大、目标被部分遮挡、样本图片中仅包含部分目标等情况,网络并不能在此种复杂情况下很好地学习到对应目标特征,致使嵌入向量较为偏离类原型,其对原型的贡献与样本目标较为明显的嵌入向量不一致,据此本文提出使用代表特征网络模块来学习更好地表达原型。由于学习目标较为简单,输入是具有高层抽象特征的嵌入向量,所以使用多层感知机结构,在保证效果的同时可以提高效率,代表特征模块的具体结构如图 3所示,输入为$N$个堆叠的嵌入向量,输出为相同维度的代表特征向量。

图 3 代表特征模块结构
Fig. 3 Representative feature module architecture

代表特征模块的本质是学习一个线性映射,即学习如何根据类支持集样本中目标的不同明显程度,给$N$个类支持向量分配合适的权重,映射得到相对均值更加鲁棒的类代表特征。学习的线性映射为

$ \mathit{\boldsymbol{y}} = \mathit{\boldsymbol{Wx}} + \mathit{\boldsymbol{b}} $ (5)

式中,$\mathit{\boldsymbol{W}}$是学习的权重矩阵,$\mathit{\boldsymbol{b}}$是偏置项。首先使用一层全连接层将$N$个嵌入向量映射到更高的维度$M$,这里使用128维。可以理解为从$N$个嵌入向量中学习得到$M$个嵌入向量($M>N$),其中每个嵌入向量都共享$N$个嵌入向量的信息,但是每个都不同。该层全连接层学习更高维的嵌入向量的意义在于分散$N$个嵌入向量的特征,使得特征分布更加均匀。由此,第2层全连接层从$M$个嵌入向量中提取的代表特征向量更加鲁棒。实验结果表明,在miniImageNet数据集5-way,5-shot任务上,先提高维度再提取代表特征的准确率比仅使用一层全连接层直接提取代表特征的准确率高0.6%。

实验表明,学习机制可以比简单取支持集向量均值更好地表达原型,效果示意如图 4所示。嵌入模块首先将样本映射到嵌入空间,嵌入空间中的圆代表各个样本的嵌入向量。图 4中共3类,每类5个支持样本,1个测试样本。可以看出,当某个目标不明显的样本对应的支持向量偏离类原型表达较远时,本文方法学习出的代表特征更有代表性,并且避免了该测试样本的分类错误。图 4中的实线表示使用代表特征距离最近的类别,虚线表示使用均值原型距离最近的类别。该模块接受支持集$\mathit{\boldsymbol{S}}_{k}$中各类堆叠的$N_{\rm {S}}$个嵌入向量作为输入,学习函数映射输出代表特征向量,即

图 4 代表特征效果示意图
Fig. 4 Diagram of representative feature's effects

$ {\mathit{\boldsymbol{p}}_k} = {R_{{\theta _2}}}\left( {{f_{{\theta _1}}}{{\left( {{\mathit{\boldsymbol{x}}_i}} \right)}_{\left( {{\mathit{\boldsymbol{x}}_i},{y_i}} \right) \in {\mathit{\boldsymbol{S}}_k}}}} \right) $ (6)

式中,$f_{θ_{1}}$是嵌入模块学习的函数映射,$θ_{1}$是嵌入模块的可学习参数。$R_{θ_{2}}$是代表特征模块学习的函数映射,$θ_{2}$是代表特征模块的可学习参数。

对于测试样本$\mathit{\boldsymbol{\hat x}}$,在代表特征上计算的注意核为

$ a\left( {\mathit{\boldsymbol{\hat x}},{\mathit{\boldsymbol{p}}_k}} \right) = \frac{{\exp \left( {c\left( {{f_{{\theta _1}}}(\mathit{\boldsymbol{\hat x}}),{\mathit{\boldsymbol{p}}_k}} \right)} \right)}}{{\sum\limits_{{k^\prime }} {\exp } \left( {c\left( {{f_{{\theta _1}}}(\mathit{\boldsymbol{\hat x}}),{\mathit{\boldsymbol{p}}_{{k^\prime }}}} \right)} \right)}} $ (7)

式中,$c$函数是余弦距离。

然后即可得到测试样本$\mathit{\boldsymbol{\hat x}}$的预测标签$\hat y$

$ \hat y = \sum\limits_{k = 1}^h a \left( {\mathit{\boldsymbol{\hat x}},{\mathit{\boldsymbol{p}}_k}} \right){y_k} $ (8)

1.1.3 特征提取路径

在分段训练方法中,每个分段均作为训练的一次迭代,一个分段包含支持集和测试集。支持集的数量构成依据的是小样本任务类型的数量($N$类way)和每个类型中样本的数量($K$个shot),例如5-way, 5-shot即表示一个分类的支持集由5类、每类5个样本,共25个样本构成。测试集的数量没有特定要求。由于支持集有多个类别,每类有多张,而小样本分类任务就是基于该支持集将测试样本归为支持集的某个类别,所以测试样本预测类别的优劣依赖于支持集类特征的优劣。据此支持集需要通过代表特征模块提取更有类代表性的特征向量,而测试样本则不需要通过代表特征模块。

具体来说,在本文的代表特征网络中,一个分段中的支持集和测试集在网络中通过的路径不同:支持集先经过嵌入模块得到各个样本的嵌入向量,然后堆叠同类向量输入代表特征模块得到代表特征向量;测试集仅经过嵌入模块提取出测试样本的嵌入向量,维度与代表特征向量维度相同。然后计算测试样本嵌入向量与各个代表特征向量的余弦距离,使用注意力机制选择距离最近的代表特征类别作为测试样本的预测类别,如图 5所示。嵌入模块用于提取基础特征,分段中的支持集和测试集都经过嵌入模块以提取嵌入向量,代表特征模块仅用于提取支持集代表特征向量。而后计算测试图片嵌入向量与各个类代表特征向量的余弦距离用于预测类别。

图 5 特征提取路径图
Fig. 5 Paths of extracting features

1.2 混合损失函数

对于测试样本$\mathit{\boldsymbol{\hat x}}$,其类别为$k$,一般损失函数为交叉熵损失函数(cross entropy loss function), 即

$ J(\mathit{\boldsymbol{\theta }}) = - {\log _2}{\mathit{\boldsymbol{P}}_\theta }(\mathit{\boldsymbol{y}} = \mathit{\boldsymbol{x}}|\mathit{\boldsymbol{\hat x}},\mathit{\boldsymbol{p}}) $ (9)

式中,$\mathit{\boldsymbol{P}}$是整个支持集嵌入向量的类代表特征集。

研究发现,分类任务中有部分类别较为相似的情况,如miniImageNet数据集[28]中的纽芬兰犬和戈登雪达犬类别,两者同属犬类;Cifar100数据集[33]中的仓鼠和老鼠类别,两者同属啮齿动物且相似度较大。由于部分类别的样本量极少,较容易导致该种相似类别错分,所以通过拉远各个类间距离可以有效降低相似类别错分概率。即优化后测试样本嵌入向量与同类代表特征的距离变近,与异类(相似类)代表特征的距离变远,从而避免对相似类别的错误分类。本文据此提出混合损失函数,在交叉熵损失中添加相对误差损失项,对分段中有$n$类支持集的小样本分类任务,有

$ \begin{array}{*{20}{c}} {J(\mathit{\boldsymbol{\theta }}) = - {{\log }_2}{\mathit{\boldsymbol{P}}_\theta }(\mathit{\boldsymbol{y}} = \mathit{\boldsymbol{x}}|\mathit{\boldsymbol{\hat x}},\mathit{\boldsymbol{p}}) + }\\ {\frac{1}{{n - 1}}\sum\limits_{i = 1}^n {\left( {1 - {z_i}} \right)a\left( {\mathit{\boldsymbol{\hat x}},{\mathit{\boldsymbol{p}}_i}} \right)} + {z_k}\left( {1 - a\left( {\mathit{\boldsymbol{\hat x}},{\mathit{\boldsymbol{p}}_k}} \right)} \right)} \end{array} $ (10)

式中,${z_i} = \left\{ {\begin{array}{*{20}{l}} 0&{{y_i} \ne k}\\ 1&{{y_i} = k} \end{array}, a\left({\mathit{\boldsymbol{\hat x}}, {\mathit{\boldsymbol{p}}_i}} \right)} \right.$是测试样本与$i$类代表特征的注意核。即在测试样本${\mathit{\boldsymbol{\hat x}}}$标签$y$$k$类的情况下,有:

1) 对支持集中标签$y_{i}$$k$的类代表特征,有$z_{i}=1$,则式(10)中的第2项为0,第3项为$(1-a(\mathit{\boldsymbol{\hat x}}, \mathit{\boldsymbol{p}}_{k}))$,而要使第3项变小(最小化损失),则要求$a(\mathit{\boldsymbol{\hat x}}, \mathit{\boldsymbol{p}}_{k})$变大,即要求$\mathit{\boldsymbol{\hat x}}$更接近$k$类代表特征$\mathit{\boldsymbol{p}}_{k}$

2) 对支持集中标签$y_{i}$$k$的类代表特征,有$z_{i}=0$,则式(10)的第2项为$\frac{1}{{n - 1}}\sum\limits_{i = 1}^n a \left({\mathit{\boldsymbol{\hat x}}, {\mathit{\boldsymbol{p}}_i}} \right)$,第3项为0,而要使第2项变小,则要求$a(\mathit{\boldsymbol{\hat x}}, \mathit{\boldsymbol{p}}_{k})$变小,即要求$\mathit{\boldsymbol{\hat x}}$远离非$k$类代表特征$\mathit{\boldsymbol{p}}_{k}$

由上可知,损失函数会拉近测试样本$\mathit{\boldsymbol{\hat x}}$与支持集中同类别代表特征的距离,拉远与支持集中不同类别代表特征的距离, 再由代表特征反馈到支持集嵌入向量,致使整个嵌入空间各类间距离变大。混合损失函数的效果如图 6所示, 左边是原嵌入空间,右边是本文的混合损失函数调整后的嵌入空间。从图 6可以看出,原本相距较近的纽芬兰犬和戈登雪达犬的距离进一步拉大,使得测试样本纽芬兰犬的正确分类更加容易。

图 6 混合损失函数效果示意图
Fig. 6 Diagram of mixture loss function's effects

混合损失函数由一般损失项(交叉熵损失项)和相对误差损失项组成,每项对总损失的贡献程度无需添加额外的权重项学习。各项在总损失中的占比及分配原则为:相对误差损失用于优化相似类别错误分类情况,是辅助任务,占比较小;交叉熵损失项是主任务,占比较大。本文通过限制各项值域的范围达到区分重要性的目的。

交叉熵损失项值域范围为

$ - {\log _2}{P_\theta }(\mathit{\boldsymbol{y}} = \mathit{\boldsymbol{x}}|\mathit{\boldsymbol{\hat x}},\mathit{\boldsymbol{p}}) \in (0, + \infty ) $ (11)

由式(8),得

$ a\left( {\mathit{\boldsymbol{\hat x}},{\mathit{\boldsymbol{p}}_i}} \right) \in (0,1) $ (12)

则相对误差损失项值域范围为

$ \frac{1}{{n - 1}}\sum\limits_{i = 1}^n {\left( {1 - {z_i}} \right)a\left( {\mathit{\boldsymbol{\hat x}},{\mathit{\boldsymbol{p}}_i}} \right)} \in (0,1) $ (13)

$ {z_k}\left( {1 - a\left( {\mathit{\boldsymbol{\hat x}},{\mathit{\boldsymbol{p}}_k}} \right)} \right) \in (0,1) $ (14)

由式(13)(14)可知,相对误差损失项对总损失的贡献有限,不会对总损失造成较大偏差,总损失主要由交叉熵损失贡献,以此达到了区分重要性的目的。

训练即为使用梯度下降法最小化混合损失函数,按照小样本任务要求,从训练集中不重复地随机挑选一个分段样本集,再将一个分段分为支持集和测试集,基于支持集预测测试集标签,并与真实标签比对形成损失,反向传播求梯度更新参数。

2 实验结果与分析

实验是在英特尔Xeon E3-1231V3处理器、32 GB内存、NVIDIA GeForce GTX TITANX 12 GB显卡、Ubuntu15.04操作系统上,通过PyTorch深度学习框架实现。任务都是基于$N$类(way)、$K$个(shot)的小样本学习任务,即提供$N$类没有训练过的新类、每类$k$个标注样本,然后基于该样本集预测测试图片是否属于这$N$类中的某个类。本文在训练时采用与测试时相同的类别数,例如在5-way,5-shot任务上即为使用5类、每类5个样本组成训练分段中的支持集。在距离函数上,本文使用余弦距离计算测试样本嵌入向量与各个类代表特征的距离。

本文使用3个数据集进行实验,包括Omniglot数据集[15]和miniImageNet数据集[22, 28] (2012年ImageNet大规模视觉识别竞赛中的版本(ILSVRC-2012)[35]);另外Cifar100数据集[33]对小样本学习任务比较合适,在该数据集上也做了实验,验证本文方法的能力。

本文使用Adam优化算法[36]对实验模型进行训练,并使用L2正则化方法减少过拟合。初始学习率设为0.001,1 000个分段为一代,每10代学习率减半,共训练100代。

2.1 Omniglot小样本学习任务

Omniglot数据集[15]由50个字母表共1 623类字符组成,每类字符都是20个样本,由不同的人手写而成,该数据集类别多,每类样本少,很适合小样本学习任务。本文采用与匹配网络[28]一样的数据集设置,即重设样本尺寸为$28×28$像素,并且使用旋转90°的样本增强方法,1 200类用于训练,123类用于验证,300类用于测试。

实验结果如表 2所示,可以看出,本文的效果较匹配网络[28]和原型网络[29]都有较大提升,在大部分任务上都有目前所知的最好效果。

表 2 在Omniglot数据集上的小样本分类精度
Table 2 Few-shot classification accuracies on Omniglot dataset

下载CSV
/%
模型 微调 5-way准确率 20-way准确率
1-shot 5-shot 1-shot 5-shot
匹配网络[28] 98.1 98.9 93.8 98.5
匹配网络[28] 97.9 98.7 93.5 98.7
原型网络[29] 98.8 99.7 96.0 98.9
未知模型元学习[23] 98.7±0.4 99.9±0.1 95.8±0.3 98.9±0.2
相关网络[32] 99.6±0.2 99.8±0.1 97.6±0.2 99.1±0.1
代表特征网络 98.6 99.6 96.3 98.9
代表特征网络 99.3 99.7 98.3 99.4

2.2 miniImageNet小样本学习任务

miniImageNet[28]数据集是在匹配网络中首次提出的,包括从ILSVRC-2012[34]数据集中挑选的100类、每类600张,共计60 000张样本,重设尺寸为$84×84$像素。由于匹配网络[28]中没有给出具体的类别名称,采用Ravi等人[22]实验时使用的类别以便比较,64类用于训练,16类用于验证,20类用于测试,本文也使用同样的设置。

miniImageNet数据集在小样本分类任务上基本属于基准数据集,由于使用的图像分类领域的基准数据集ImageNet的数据复杂,具有较高的复杂度,可以很好地检验模型方法的能力,数据分布和每类样本数量又很适合小样本分类任务,所以一经提出即很受欢迎。本文在该数据集上进行了详细的对比实验,验证本文各个优化方法的能力,并且与匹配网络[28]、元学习(LSTM)[22]、未知模型元学习[23]、原型网络、相关网络[32]等近年来的多个主流方法进行对比,如表 3所示。

表 3 在miniImageNet上的小样本分类精度
Table 3 Few-shot classification accuracies on miniImageNet dataset

下载CSV
/%
模型 微调 5-way准确率
1-shot 5-shot
最近邻基准网络[22] 28.86±0.54 49.79±0.79
匹配网络[28] 41.2 56.2
匹配网络[28] 42.4 58.0
匹配网络
FCE[28]
44.2 57.0
匹配网络
FCE[28]
46.6 60.0
元学习
LSTM[22]
43.44±0.77 60.60±0.71
未知模型元学习[23] 48.70±1.84 63.11±0.92
原型网络[29](cosine) 42.48±0.74 51.23±0.63
原型网络[29] 49.42±0.78 68.20±0.66
相关网络[32] 57.02±0.92 71.07±0.69
代表特征网络 51.02±0.68 65.07±0.69
代表特征网络 60.39±0.73 75.12±0.71
注:加粗字体表示最优结果。

需要说明的是,实验时采用的分段中测试集每类的样本数量是10张,而大部分小样本学习工作采用的都是15张。原因在于使用本文提出的代表特征网络比原本的简单CNN结构更耗显存,导致单卡显存不足,而为了保持其他重要的超参数不变,减少了测试样本的数量。实验证明,在原匹配网络[28]中,每类10张测试样本的准确率比15张的准确率低1%左右,本文的实验结果是实际使用每类10张测试样本的实验结果,并没有加上这1%,并且使用5次实验求波动范围作为最终实验结果的方法确保实验结果的准确性。

1) 针对1.1节提出的代表特征网络的对比实验。在原匹配网络[28]上仅修改简单CNN网络为代表特征网络,损失函数仍然使用原交叉熵损失。在5-way,5-shot任务上测试,本文的代表特征网络方法使得分类效果提升了约8%,如表 4所示。

表 4 在miniImageNet数据集上代表特征网络效果
Table 4 Effect of representative feature network on miniImageNet dataset

下载CSV
/%
模型 5-way准确率
1-shot 5-shot
匹配网络[28] 41.2 56.2
代表特征网络(RF) 49.92±0.70 64.27±0.74
注:加粗字体表示最优结果。

分析原因可知,原分类方法基于测试样本嵌入向量与各支持样本嵌入向量距离预测所属类别,而代表特征思想归纳抽象各支持样本嵌入向量为类代表特征向量,再参与计算与测试样本嵌入向量间的距离。由于类代表特征向量比支持向量更具有类代表性,降低了对类中目标不明显支持向量的敏感程度,有效避免了原本基于各支持向量求距离时因目标不明显支持向量造成的错误分类。并且本文使用网络学习代表特征的方式表达原型,充分考虑不同支持向量对原型的不同贡献,得到了一个相比均值更好的原型表达,使得效果提升较为明显。

2) 对不同嵌入模块的效果进行对比。本文在miniImageNet数据集上嵌入不同模块的效果如表 5所示,可以看出,ResNet18直接用在小样本学习任务上反而没有简单的4层CNN效果好。这是由于ResNet18的参数数量很大,超过简单CNN的100倍,而该任务训练数据仅有38 400,并且要对新类进行预测,导致较大程度的过拟合,使得精度在训练集上高于简单CNN,但是测试反而低于简单CNN。此外,参数数量过大会导致训练速度较慢,可以看出,ResNet18的训练耗时是简单CNN的2.74倍。而本文针对小样本任务设计的嵌入模块经过实验验证,无论是效果还是耗时都显著优于ResNet18,并且在比简单CNN耗时略多的情况下,达到了更好的效果。

表 5 在miniImageNet上本文嵌入模块效果
Table 5 Effect of our embedding module on miniImageNet dataset

下载CSV
/%
模型 5-way准确率/% 参数量 速度/(s/epoch)
1-shot 5-shot
代表特征网络(CNN模块+ RF模块) 47.63±0.68 63.07±0.74 0.1×106 634.83
代表特征网络(ResNet18模块+ RF模块) 47.10±0.65 60.14±0.73 11.2×106 1 328.54
代表特征网络(本文嵌入模块+ RF模块) 49.92±0.70 64.27±0.74 1.2×106 842.29
注:加粗字体表示最优结果。

3) 针对1.2节提出的混合损失函数的对比实验。基于代表特征网络,对比混合损失函数与原损失函数的效果。在5-way,5-shot任务上测试,混合损失函数提升了约1%,如表 6所示。

表 6 在miniImageNet上混合损失函数效果
Table 6 Effect of mixture loss function on miniImageNet dataset

下载CSV
/%
模型 5-way准确率
1-shot 5-shot
代表特征网络(RF) 49.92±0.70 64.27±0.74
代表特征网络(RF+ML) 51.02±0.68 65.07±0.69
注:加粗字体表示最优结果。

分析原因可知,对于混合损失函数,更偏向于擅长减小嵌入空间中分布相近类别造成错误分类的情况,而对于相似类别,数据集中存在但数量较少,所以该种优化方法对效果的提升没有代表特征网络明显。

4) 针对测试集上的微调进行对比。在任务定义上,微调指在测试时的一次迭代(一个分段)过程中,仅利用分段中的支持集计算损失,然后反向传播更新网络参数,再使用更新后的网络接受分段中的支持集和测试集输入进行测试,测试后恢复网络到上一次训练后的状态,而后进行下一次的迭代测试,即在整个测试集上仅能针对迭代时相应的分段样本集更新一次参数。不同模型是否进行微调的对比结果如表 3所示,可以看出,在原匹配网络[28]中,微调有3%左右的帮助,而本文方法在微调后,效果提高了近10%,体现了较大优势。

分析原因可知,首先是代表特征网络中跳过连接结构的优势。由于微调时本文利用新类支持集仅更新一次网络参数,而不是迭代多次,而网络在训练几千个分段后已经基本成型,所以网络参数不会有较大改动,不容易过拟合。这时本文考虑的问题是在仅更新一次网络参数的情况下能否有效学习到新类信息。本文对普通网络和跳过连接结构网络在反向传播中梯度的计算进行简单对比。

在普通网络中

$ {\mathit{\boldsymbol{z}}_k} = {\mathit{\boldsymbol{w}}_k}{\mathit{\boldsymbol{a}}_{k - 1}} + {\mathit{\boldsymbol{b}}_k} $ (15)

$ {\mathit{\boldsymbol{a}}_k} = g\left( {{\mathit{\boldsymbol{z}}_k}} \right) = g\left( {{\mathit{\boldsymbol{w}}_k}{\mathit{\boldsymbol{a}}_{k - 1}} + {\mathit{\boldsymbol{b}}_k}} \right) $ (16)

式中,$\mathit{\boldsymbol{z}}_{k}$表示第$k$层线性变换后的输出,$\mathit{\boldsymbol{w}}_{k}$$\mathit{\boldsymbol{b}}_{k}$$k$层的权重参数及偏置项,$\mathit{\boldsymbol{a}}_{k}$$k$层的输出,函数$g$代表相应的激活函数。这时反向传播梯度为

$ \frac{{\partial J}}{{\partial {\mathit{\boldsymbol{a}}_{k - 1}}}} = \frac{{\partial J}}{{\partial {\mathit{\boldsymbol{a}}_k}}}\frac{{\partial {\mathit{\boldsymbol{a}}_k}}}{{\partial {\mathit{\boldsymbol{a}}_{k - 1}}}} = {\left( {{\mathit{\boldsymbol{w}}_k}} \right)^{\rm{T}}}\frac{{\partial J}}{{\partial {\mathit{\boldsymbol{a}}_k}}}{g^\prime }\left( {{\mathit{\boldsymbol{a}}_{k - 1}}} \right) $ (17)

在跳过连接结构网络中

$ {\mathit{\boldsymbol{z}}_k} = {\mathit{\boldsymbol{w}}_k}{\mathit{\boldsymbol{a}}_{k - 1}} + {\mathit{\boldsymbol{b}}_k} $ (18)

$ \begin{array}{*{20}{c}} {{\mathit{\boldsymbol{a}}_k} = g\left( {{z_k} + {\mathit{\boldsymbol{a}}_{k - 1}}} \right) = g\left( {{\mathit{\boldsymbol{w}}_k}{\mathit{\boldsymbol{a}}_{k - 1}} + {\mathit{\boldsymbol{b}}_k} + {\mathit{\boldsymbol{a}}_{k - 1}}} \right) = }\\ {g\left( {\left( {{\mathit{\boldsymbol{w}}_k} + 1} \right){\mathit{\boldsymbol{a}}_{k - 1}} + {\mathit{\boldsymbol{b}}_k}} \right)} \end{array} $ (19)

这时反向传播梯度为

$ \frac{{\partial J}}{{\partial {\mathit{\boldsymbol{a}}_{k - 1}}}} = \frac{{\partial J}}{{\partial {\mathit{\boldsymbol{a}}_k}}}\frac{{\partial {\mathit{\boldsymbol{a}}_k}}}{{\partial {\mathit{\boldsymbol{a}}_{k - 1}}}} = {\left( {{\mathit{\boldsymbol{w}}_k} + 1} \right)^{\rm{T}}}\frac{{\partial J}}{{\partial {\mathit{\boldsymbol{a}}_k}}}{g^\prime }\left( {{\mathit{\boldsymbol{a}}_{k - 1}}} \right) $ (20)

对比式(17)(20)可见,跳过连接结构使得经过的梯度变大,更容易使参数对新样本做出相应调整,而普通结构由于多层级联,梯度层层传递,网络前层的梯度很小,无法做出有效更新。

其次,在测试时仅能进行一次更新的情况下,模型针对相应支持集类求得一个比均值更好的代表特征表达和拉伸各原型间距离都尤为重要。在测试前,仅能基于一个分段中的支持集更新网络,而不是训练时的1 000个分段,因此新类概念严重依赖该支持集,如果其中有部分支持样本目标不明显,就需要本文的代表特征来减少对这些样本的敏感性,从而更接近真实原型。此时如果使用均值原型,由于对目标不明显样本较为敏感,使得距离真实原型较远,容易导致错误分类。如果在测试时可以基于多个分段更新参数,由于每个新类的样本数量增加使得样本向量空间更加稠密,则其质心(均值原型)获得一定的鲁棒性,对目标不明显样本敏感性降低,更容易接近原型。这种情况下,代表特征的优势则不太明显。因此在测试时仅基于一个分段中支持集微调时,代表特征有明显优势。对拉伸距离的分析类似。如图 7所示,本文的分类精度在5-way、1-shot任务上为61.12%,在5-way、5-shot任务上为75.83%,达到了目前已知的最好效果。在5-way、5-shot任务上的分类精度比基于匹配网络[28]方法的分类精度提高了18%,比目前已知的最好方法相关网络[32]高4%。

图 7 在miniImageNet数据集上的5-way, 5-shot和5-way, 1-shot分类任务测试精度
Fig. 7 Testing accuracy on miniImageNet dataset of 5-way, 5-shot and 5-way, 1-shot few-shot classification task ((a)5-way, 5-shot; (b)5-way, 1-shot)

5) 在性能上与相关主流方法进行对比。如表 7所示,匹配网络[28]和原型网络[29]嵌入网络都是使用简单的4层CNN结构,参数量较少(0.1×106),运算复杂度较低,所以效率较高,但效果相对较差。而匹配网络FCE (full context embeddings) [28]使用了LSTM[22],相关网络[32]使用了较为复杂的嵌入网络,运算复杂度较高,效率上低于本文方法。本文方法效率比匹配网络FCE[28]高约8%,比相关网络[32]高约18%。另外,由于元学习LSTM[22]多使用RNN实现,效率上普遍低于度量学习方法,本文方法效率比其高约52%。综上所述,本文方法不仅效果很好,在性能上也有较大优势。

表 7 在miniImageNet上的性能对比
Table 7 Comparison of efficiency on miniImageNet dataset

下载CSV
/%
模型 5-way准确率/% 速度/
(s/episode)
1 -shot 5-shot
匹配网络[28] 41.2 56.2 0.28
匹配网络
FCE[28]
44.2 57.0 0.60
原型网络[29] 49.42±0.78 68.20±0.66 0.29
元学习
LSTM[22]
43.44±0.77 60.60±0.71 1.14
相关网络[32] 57.02±0.92 71.07±0.69 0.67
代表特征网络 60.39±0.73 75.12±0.71 0.55
注:加粗字体表示最优结果。

2.3 Cifar100小样本学习任务

Cifar100数据集[33]有100个类,每个类包含600个尺寸为$32×32$像素的样本,分为500个训练样本和100个测试样本。其中的100个类分成20个大类。每个样本都带有一个fine标签(所属的类)和一个coarse标签(所属的大类)。不同大类间差别较大,如花卉、鱼、家用电器等;相同大类中不同类间差别较小,如大类花卉中有兰花、罂粟花、玫瑰等类。该任务很适合小样本学习。重新组织数据集,将每个类的训练样本和测试样本合并到一起,然后从每个大类中随机挑1个类组成20个类作为验证集,同样从大类中不重复地再挑1个类组成20个类作为测试集,剩下的60个类作为训练集。

同样通过5次实验求波动范围得到实验结果,如表 8所示。由表 8可知,本文方法在Cifar100数据集上也有很好的效果,在5-way,5-shot任务上,与匹配网络相比,本文提高了22%。由此可见,在诸多数据集上都验证了本文方法的效果,表明本文方法具有很好的泛化能力。

表 8 在Cifar100数据集上的小样本分类精度
Table 8 Few-shot classification accuracies on Cifar100 dataset

下载CSV
模型 微调 5-way准确率/%
1-shot 5-shot
匹配网络[28] 53.82±0.71 65.55±0.74
匹配网络[28] 54.52±0.61 67.13±0.71
代表特征网络(代表特征) 55.45±0.65 69.02±0.67
代表特征网络(代表特征+混合
Loss)
56.36±0.72 70.25±0.69
代表特征网络(代表特征+混合
Loss)
72.43±0.68 87.29±0.70
注:加粗字体表示最优结果。

3 结论

本文基于匹配网络[28]训练方法并受原型思想启发提出代表特征网络。1)针对样本特征抽象层次较低的情况,提出改进的嵌入模块,进一步学习样本的高层特征表达。2)针对类中目标不明显样本导致错误分类的情况,提出代表特征模块,使用网络学习求代表特征的方法,降低目标不明显样本对类代表特征的影响,提高类代表特征的鲁棒性,从而完成正确分类。3)针对相似类别分类错误的情况,提出混合损失函数,拉大嵌入空间中各类别间距降低相似类别错分概率。

本文通过实验验证了代表特征网络在各个主流数据集上的优秀表现。但是该方法仍存在优化空间,对于嵌入网络,理论上还存在特征抽象性更好、参数更少的网络,因此设计更好的嵌入网络是该方法的优化方向之一。

参考文献

  • [1] He K M, Zhang X Y, Ren S Q, et al. Deep residual learning for image recognition[C]//Proceedings of 2016 IEEE Conference on Computer Vision and Pattern Recognition. Las Vegas, USA: IEEE, 2016: 770-778.[DOI: 10.1109/CVPR.2016.90]
  • [2] Szegedy C, Liu W, Jia Y Q, et al. Going deeper with convolutions[C]//Proceedings of 2015 IEEE Conference on Computer Vision and Pattern Recognition. Boston, USA: IEEE, 2015: 1-9.[DOI: 10.1109/CVPR.2015.7298594]
  • [3] Krizhevsky A, Sutskever I, Hinton G E. Imagenet classification with deep convolutional neural networks[C]//Proceedings of the 25th International Conference on Neural Information Processing Systems. Lake Tahoe, USA: ACM, 2012: 1097-1105.
  • [4] Weng Y C, Tian Y, Lu D M, et al. Fine-grained bird classification based on deep region networks[J]. Journal of Image and Graphics, 2017, 22(11): 1521–1531. [翁雨辰, 田野, 路敦民, 等. 深度区域网络方法的细粒度图像分类[J]. 中国图象图形学报, 2017, 22(11): 1521–1531. ] [DOI:10.11834/jig.170262]
  • [5] Ren S Q, He K M, Girshick R, et al. Faster R-CNN: towards real-time object detection with region proposal networks[C]//Proceedings of the 28th International Conference on Neural Information Processing Systems. Montreal, Canada: ACM, 2015: 91-99.
  • [6] Redmon J, Divvala S, Girshick R, et al. You only look once: unified, real-time object detection[C]//Proceedings of 2016 IEEE Conference on Computer Vision and Pattern Recognition. Las Vegas, USA: IEEE, 2016: 779-788.[DOI: 10.1109/CVPR.2016.91]
  • [7] Liu W, Anguelov D, Erhan D, et al. SSD: single shot multibox detector[C]//Proceedings of the 14th European Conference on Computer Vision. Amsterdam, The Netherlands: Springer, 2016: 21-37.[DOI: 10.1007/978-3-319-46448-0_2]
  • [8] Zhao W Q, Yan H, Shao X Q. Object detection based on improved non-maximum suppression algorithm[J]. Journal of Image and Graphics, 2018, 23(11): 1676–1685. [赵文清, 严海, 邵绪强. 改进的非极大值抑制算法的目标检测[J]. 中国图象图形学报, 2018, 23(11): 1676–1685. ] [DOI:10.11834/jig.180275]
  • [9] Vaswani A, Shazeer N, Parmar N, et al. Attention is all you need[C]//Proceedings of Advances in Neural Information Processing Systems. California, USA: NIPS, 2017: 5998-6008.
  • [10] Sutskever I, Vinyals O, Le Q V. Sequence to sequence learning with neural networks[C]//Proceedings of Advances in Neural Information Processing Systems. Montreal, Canada: NIPS, 2014: 3104-3112.
  • [11] He D, Xia Y C, Qin T, et al. Dual learning for machine translation[C]//Proceedings of the Advances in 30th Conference on Neural Information Processing Systems. Barcelona, Spain: NIPS, 2016: 820-828.
  • [12] Li Y C, Xiong D Y, Zhang M. A survey of neural machine translation[J]. Chinese Journal of Computers, 2018, 41(12): 2734–2755. [李亚超, 熊德意, 张民. 神经机器翻译综述[J]. 计算机学报, 2018, 41(12): 2734–2755. ] [DOI:10.11897/SP.J.1016.2018.02734]
  • [13] Koch G R, Zemel R, Salakhutdinov R. Siamese neural networks for one-shot image recognition[C]//Proceedings of the 32nd International Conference on Machine Learning. Lille Grande Palais, France: ICML, 2015: #2.
  • [14] Li F F, Fergus R, Perona P. One-shot learning of object categories[J]. IEEE Transactions on Pattern Analysis and Machine Intelligence, 2006, 28(4): 594–611. [DOI:10.1109/TPAMI.2006.79]
  • [15] Lake B, Salakhutdinov R, Gross J, et al. One shot learning of simple visual concepts[C]//Proceedings of the Annual Meeting of the Cognitive Science Society. Boston, USA: CogSci, 2011: 2568-2573.
  • [16] Bengio Y. Deep learning of representations for unsupervised and transfer learning[C]//Proceedings of ICML Workshop on Unsupervised and Transfer Learning. Bellevue, Washington, USA: ICML, 2011: 17-36.
  • [17] Pan S J, Yang Q. A survey on transfer learning[J]. IEEE Transactions on Knowledge and Data Engineering, 2010, 22(10): 1345–1359. [DOI:10.1109/TKDE.2009.191]
  • [18] Luo Z L, Zou Y L, Hoffman J, et al. Label efficient learning of transferable representations acrosss domains and tasks[C]//Proceedings of Advances in Neural Information Processing Systems. Long Beach, USA: NIPS, 2017: 165-177.
  • [19] Yosinski J, Clune J, Bengio Y, et al. How transferable are features in deep neural networks?[C]//Proceedings of Advances in Neural Information Processing Systems. Montreal, Canada: NIPS, 2014: 3320-3328.
  • [20] Dixit M, Kwitt R, Niethammer M, et al. AGA: attribute-guided augmentation[C]//Proceedings of 2017 IEEE Conference on Computer Vision and Pattern Recognition. Honolulu, Hawaii, USA: IEEE, 2017: 3328-3336.[DOI: 10.1109/CVPR.2017.355]
  • [21] Hariharan B, Girshick R. Low-shot visual recognition by shrinking and hallucinating features[C]//Proceedings of 2017 IEEE Conference on International Computer Vision. Venice, Italy: IEEE, 2017: 3037-3046.[DOI: 10.1109/ICCV.2017.328]
  • [22] Ravi S, Larochelle H. Optimization as a model for few-shot learning[C]//Proceedings of the 5th International Conference on Learning Representations. Toulon, France: ICLR, 2017.
  • [23] Finn C, Abbeel P, Levine S. Model-agnostic meta-learning for fast adaptation of deep networks[J]. arXiv: 1703.03400, 2017.
  • [24] Santoro A, Bartunov S, Botvinick M, et al. Meta-learning with memory-augmented neural networks[C]//Proceedings of the 33rd International Conference on Machine Learning. New York, USA: ICML, 2016: 1842-1850.
  • [25] Cheng G, Zhou P C, Han J W. Duplex metric learning for image set classification[J]. IEEE Transactions on Image Processing, 2018, 27(1): 281–292. [DOI:10.1109/TIP.2017.2760512]
  • [26] Cheng G, Yang C Y, Yao X W, et al. When deep learning meets metric learning:remote sensing image scene classification via learning discriminative CNNs[J]. IEEE Transactions on Geoscience and Remote Sensing, 2018, 56(5): 2811–2821. [DOI:10.1109/TGRS.2017.2783902]
  • [27] Graves A, Wayne G, Danihelka I. Neural turing machines[EB/OL].[2018-10-28]. https://arxiv.org/pdf/1410.5401.pdf.
  • [28] Vinyals O, Blundell C, Lillicrap T, et al. Matching networks for one shot learning[C]//Proceedings of Advances in Neural Information Processing Systems. Barcelona, Spain: NIPS, 2016: 3630-3638.
  • [29] Snell J, Swersky K, Zemel R. Prototypical networks for few-shot learning[C]//Proceedings of Advances in Neural Information Processing Systems. California, USA: NIPS, 2017: 4077-4087.
  • [30] Fort S. Gaussian prototypical networks for few-shot learning on Omniglot[EB/OL].[2018-10-28]. https://arxiv.org/pdf/1410.5401.pdf.
  • [31] Mehrotra A, Dukkipati A. Generative adversarial residual pairwise networks for one shot learning[EB/OL].[2018-10-28]. https://arxiv.org/pdf/1410.5401.pdf.
  • [32] Sung F, Yang Y X, Zhang L, et al. Learning to compare: relation network for few-shot learning[EB/OL].[2018-10-28]. https://arxiv.org/pdf/1410.5401.pdf.
  • [33] Krizhevsky A, Nair V, Hinton G. The CIFAR-10 dataset (Canadian institute for advanced research)[EB/OL].[2018-10-28]http://www.cs.toronto.edu/kriz/cifar.html.
  • [34] Banerjee A, Merugu S, Dhillon I S, et al. Clustering with Bregman divergences[J]. The Journal of Machine Learning Research, 2005, 6: 1705–1749.
  • [35] Russakovsky O, Deng J, Su H, et al. Imagenet large scale visual recognition challenge[J]. International Journal of Computer Vision, 2015, 115(3): 211–252. [DOI:10.1007/s11263-015-0816-y]
  • [36] Kingma D P, Ba J. Adam: a method for stochastic optimization[EB/OL].[2018-10-28] https://arxiv.org/pdf/1412.6980.pdf.