MAML元学习算法是小样本学习领域中的经典方法,本文将重点讲解该方法的理论和飞桨代码实现。本章首先对小样本学习的问题定义、评价标准和常用数据集进行介绍,以期读者对本领域概况获得基本了解。
问题定义
假设数据集中包含个类别,将这个数据集按类别划分为不相交的两部分,一部分称为基础集(Baseset),一部分称为新颖集(NovelSet),其中,,且。模型在基础集上离线训练,以获得所需的先验知识和特征提取能力。在基础集上的具体训练方式,因算法的不同而异。对新颖集随机采样N个类别,每个类别采样K个样本,这个带标签的样本构成支持集(Supportset)S,小样本学习任务便是在这个很小的数据集S上进行,称为N-wayK-shot任务。此外,对这N个类别再采样个无标签样本构成查询集Q,在Q上进行小样本模型的分类测试。通常,N=5,K=1或5。
本文所述的MAML算法,是在基础集上以相同方式构建了若干个N-wayK-shot训练任务,进行离线训练。这种训练方式是一种元学习的训练方法,保持了与测试过程相同的任务构建流程,能够最大程度避免协变量偏移。
为了更清晰地展示数据集的划分方法,这里以miniImageNet数据集[1]为例,进行图形化展示,如下图所示。该数据集共有100个类别,每个类别各有600张图像样本。
评价标准
在600轮(或1000轮等等)不同的N-wayK-shot任务上,分别进行小样本学习,得到在查询集上的top-1分类准确率。最终的评估指标是这600个任务上的平均准确率和置信区间。
常用数据集
miniImageNet[1]:由OriolVinyals等在MatchingNetworks[1]中首次提出。在MatchingNetworks中,作者提出对ILSVRC-12中的类别和样本进行抽取(参见其AppendixB),形成了一个数据子集,将其命名为miniImageNet,包含100类共60000张彩色图片,其中每类有600个样本,图像大小为84×84。随后,普林斯顿大学的博士生SachinRavi[2]将该数据集随机划分为64个基础集类,16个验证集类和20个新颖集类。
下载链接:
tieredImageNet[3]:同样是ILSVRC-12的子集,包含ImageNet中层次结构较高级别的34个大类(category),每个大类包含10~30个小类(class)。该数据集中各子集的划分方法如下表所示。
FC100[4]:即Fewshot-CIFAR100,截取自CIFAR100数据集,共包含100个类别,每个类别600张图片,图像大小为32×32×3。其中基础集60个类别,验证集和新颖集各60个类别。
Omniglot[5]:包含50个不同的字母表,每个字母表中的字母各包含20个手写字符样本,每一个手写样本都是不同人通过亚马逊MechanicalTurk在线绘制的。Omniglot数据集的多样性强于MNIST数据集,常用于小样本识别任务。
CUB[6]:该数据集是一个细粒度数据集,全部由鸟类图片构成,共包含200个类别,其中100个类别为基础集,50个类别为验证集,50个类别为新颖集。
MAML模型算法
模型无关元学习(Model-AgnosticMeta-Learning,简称MAML)算法[7],其模型无关体现在,能够与任何使用了梯度下降法的模型相兼容,广泛应用于各种不同的机器学习任务,包括图像分类、目标检测、强化学习等。元学习的目标,是在大量不同的任务上训练一个模型,使其能够使用极少量的训练数据(即小样本),进行极少量的梯度下降步数,就能够迅速适应新任务,解决新问题。
模型方法
MAML算法的训练目的是获得一组最优的初始化参数,使得模型能够快速适配(fastadaptation)新任务。作者认为,某些特征比另一些特征更容易迁移到其他任务中,这些特征具有跨任务间的通用性。既然小样本学习任务只提供少量标记样本,模型在小样本上多轮迭代训练后必然导致过拟合,那么就应该尽可能使模型只迭代训练几步。这就要求模型已经具有广泛适配于各种任务的初始化参数,这组参数应包含模型在基础集上所学到的先验知识。
假设模型可以用函数θ表示,θ为模型参数。适配新任务时,模型通过梯度下降法迭代一步(或若干步),参数θ更新为θ,即θθαθ
其中,α为超参数,用于控制适配过程的学习率。
在多个不同任务上,模型通过计算θ的损失来评估模型参数θ。具体地,元学习的目标是获得一组参数θ,使得模型在任务分布上,能够快速适配所有任务,使得损失最小。用公式表达如下:
通过随机梯度下降(SGD)法,模型参数θ按照以下公式进行更新:
这里需要注意,我们最终要优化的参数是θ,但计算损失函数却是在微调后的参数θ上进行,训练过程可通过下图示意。
由于上述元学习算法在损失计算和优化参数方面的特点,训练包括了两层循环。外层循环是元学习过程,通过在任务分布上采样一组任务,计算在这组任务上的损失函数;内层循环是微调过程,即针对每一个任务,迭代一次(或若干次)梯度下降,将参数进行更新为θ,然后计算在参数为θ时的损失。梯度反向传递时,需要跨越两层循环传递到初始参数θ上,完成元学习的参数更新。
完整的MAML算法如下图所示。
实验结果
在Omniglot和miniImageNet数据集上,文献给出的实验结果如下图所示。
飞桨实现
本小节给出本人在“飞桨论文复现挑战赛(第三期)”中完成的部分关键代码。完整项目代码已在GitHub和AIStudio上开源,欢迎读者star、fork。链接如下:
GitHub地址:
AIStudio地址:
关键代码实现
该模型比较特殊,梯度需要穿过内外两层循环传递到原始参数。如果基于nn.Layer类进行常规的模型搭建,在内循环更新梯度时,模型参数会被覆盖,导致初始参数丢失。得益于飞桨动态图模式灵活组网的特点,本项目将模型参数和算子分离设计,在外循环中保存原始参数副本θ;内循环中通过该副本更新参数,计算损失函数。计算图通过动态图模式自动构建,最终将梯度反传回原始参数θ。
MAML类的代码如下:
元学习器类的代码如下:
复现结果
本项目在Omniglot数据集上进行了实验复现,其复现的结果如下表所示:
小结
参考文献
[1]VinyalsO,BlundellC,LillicrapT,etal.MatchingNetworksforOneShotLearning[J],2016.
[2]RaviS,LarochelleH.Optimizationasamodelforfew-shotlearning[J],2016.
[3]RenM,TriantafillouE,RaviS,etal.Meta-learningforsemi-supervisedfew-shotclassification[J].arXivpreprintarXiv:1803.00676,2018.
[4]OreshkinBN,RodriguezP,LacosteA.Tadam:Taskdependentadaptivemetricforimprovedfew-shotlearning[J].arXivpreprintarXiv:1805.10123,2018.
[5]LakeB,SalakhutdinovR,GrossJ,etal.Oneshotlearningofsimplevisualconcepts[C].Proceedingsoftheannualmeetingofthecognitivesciencesociety,2011.
[6]WahC,BransonS,WelinderP,etal.Thecaltech-ucsdbirds-200-2011dataset[J],2011.
[7]FinnC,AbbeelP,LevineS.Model-agnosticmeta-learningforfastadaptationofdeepnetworks[C].InternationalConferenceonMachineLearning,2017:1126-1135.