LearnorRecallRevisitingIncrementalLearningwithPre-trainedLanguageModels
收录会议:
ACL2024,LongPaper,Oral
论文链接:
背景
增量学习(IL)一直是计算机视觉和自然语言处理(NLP)领域长期存在的问题。近年来,随着大语言模型(LargeLanguageModel,LLM)在各种NLP下游任务中取得了显著进展,将LLMs作为骨干网络在NLP领域的增量学习研究中已成为一种常见做法。
大多数研究假设灾难性遗忘是实现优越增量学习性能的最大障碍,并提出了各种技术来克服这一问题。然而,我们发现这一假设存在问题。
这些发现促使我们重新审视基于LLMs的增量学习,并鼓励未来的研究更加深入地理解LLMs中的灾难性遗忘问题。
新发现
我们利用探测技术probing评估模型backbone对目标任务的表示能力,实现如图1所示。
▲图1Probing实验图
新发现1:大模型在连续学习过程中并没有丢失其知识
我们在实验中使用生成模型进行类别增量意图分类的观察和探测性能。图2(a)显示,随着更多新任务的学习,观察到的性能显著下降,从约98%降至10%,这一结果符合我们对灾难性遗忘的理解。
然而,图2(b)描述了一个完全不同的现象。LLMs在学习第一个任务后就达到了很高的探测性能,并且从第二个任务开始,线性探测性能几乎没有下降。换句话说,即使LLMs仅按顺序适应新任务(Sequentialfine-tuning,SEQ),它们依然保留了分类所有15个任务的知识。这个现象与我们对灾难性遗忘和SEQ的理解相矛盾。
实际上,探测性能之所以很高,是因为在训练探测分类器时,所有任务的数据都是可用的,而观察到的性能较差,是因为原始分类器仅在当前任务的数据上进行训练。
因此,经过探测的实验结果表明大模型在连续学习过程中并没有丢失其知识。
(a)观测表现
(b)线性探测表现
新发现2:Probing性能:Linear>CosineLinear≈CosinePrototype>Prototype
我们发现四个探测指标的排序如下:Linear>CosineLinear≈CosinePrototype>Prototype。如图3所示:
(a)线性探测
(b)余弦探测
(c)原型探测
(d)余弦原型
▲图3四种探测指标情况
首先,我们需要分别理解LLMs的特征(即最后的隐藏状态)、词向量和探测分类器中的类别嵌入“是什么样的”。特征、词向量和类别嵌入的L2范数和余弦相似度的直方图如图4。
(a)特征相似度
(b)特征-词嵌入相似度
(c)特征范数
(d)词嵌入范数
▲图4Pythia-410m的特征和不同嵌入的直方图
图4a显示,特征在向量空间中占据一个狭窄的圆锥形区域,而不是在所有方向上均匀分布。更令人惊讶的是,图4b显示,学习到的(输出)词向量与特征几乎是正交的。我们推测,交叉熵损失函数鼓励除了真实标签外的所有词向量在预训练过程中远离特征。
换句话说,交叉熵损失鼓励logits之间有较大的差异,并且词向量与特征正交,以便更好地区分logits。因此,考虑到词向量层本质上是一个线性层,线性探测有最佳表现也就不足为奇。
从这个角度来看,原型探测表现较差也就不奇怪,因为原型(类别特征中心)也落在这个狭窄的圆锥空间内,而这对于区分logits并不是一个最优的解决方案。
那么,为什么余弦归一化会降低线性探测的性能,但能改善原型探测的性能呢?图4c和图4d展示了特征和词向量的L2范数。我们发现,词向量的范数与特征相比存在较大的差异。这表明,词向量的范数包含了来自预训练阶段的先验知识。
因此,余弦线性探测忽略了特征范数的差异,因此相比于线性探测,其性能较差。对于原型探测,原型位于一个狭窄的圆锥空间中,原型和特征之间的相似度较大,且接近彼此。在这种情况下,余弦归一化可以消除范数的干扰,从而建立logits和特征之间余弦相似度的关系。
新发现3:LLMs抵抗遗忘的关键在于Transformer的结构和预训练获取的知识
我们评估了在不同预训练步数的检查点上的线性探测性能:{0,16,128,1k,10k,143k(最终)}。我们加载预训练的检查点(或在步数为0时随机初始化的检查点),并在使用SEQ进行增量学习前后评估它们的线性探测性能。
图5展示了预训练中的两个主要阶段:过拟合和泛化。在第一个阶段(步数0-步数128),模型开始记忆预训练语料库,线性探测性能下降。在第二个阶段(步数1k-步数143k),模型逐渐学习预训练知识,线性探测性能上升。
然而,当模型进一步泛化到预训练语料库时(步数10k-步数143k),小型骨干网络(如Pythia-70m和160m)的线性探测性能再次下降,原因是预训练和下游任务之间存在差距。这个差距可以通过适应下游任务来消除。
对于较大的骨干网络(如Pythia-410m、1b和1.4b),模型能够直接适应新任务,而不会受到这种差距的影响。此外,我们还有以下有趣的发现:
(c)关系抽取(BeforeSEQ)
(d)余弦原型关系抽取(AfterSEQ)
▲图5不同训练步骤的检查点的线性探测性能
1.预训练确实改善了增量学习中的线性探测性能(见图5b和图5d)。
2.除了预训练之外,Transformer的架构也是在SEQ过程中获得高线性探测准确率的关键因素。当下游任务相对简单时,例如意图分类,即使是随机初始化的模型也能获得较高的线性探测性能(见图5b)。而当下游任务较为复杂时,例如关系抽取(见图5d),预训练则带来了显著的性能提升。
3.更令人惊讶的是,SEQ提高了几乎所有预训练步骤的模型的线性探测性能(见图5a与5b;图5c与5d)。这表明,Transformer的架构即使仅在新任务上进行顺序微调,也能够逐步吸收新知识。
新发现4:真正的遗忘发生于分类器中
我们观察到,在SEQ模型中,新类别的logits远大于旧类别的logits。由于特征和类别嵌入决定了logits的大小,而特征占据一个狭窄的圆锥空间,其范数相对接近,因此我们可以推测,遗忘现象的发生是由以下原因之一引起的:
(1)类别嵌入的范数,或(2)特征与类别嵌入之间的余弦相似度。对于第一种原因(即类别范数),我们在图6a和图6b中比较了学习的线性分类器和线性探测分类器之间的类别嵌入范数。
令人惊讶的是,在SEQ的观察分类器中,新任务的类别嵌入范数并不大于旧任务的类别嵌入范数。这表明,类别范数不是SEQ中遗忘现象的主要原因。
对于第二个原因(即余弦相似度),我们在图6c和图6d中比较了观察分类器和探测分类器之间类别嵌入的移动距离。任务t的类别嵌入在任务时的移动距离计算如下:
1.当模型完成任务的训练后,我们计算任务t的所有类别嵌入与所有任务的类别特征中心之间的余弦距离,并得到一个余弦相似度矩阵。
2.当模型完成任务t+k的训练后,我们计算任务t的所有类别嵌入与所有任务的类别特征中心之间的余弦距离,并得到一个余弦相似度矩阵。
3.然后,任务t的类别嵌入的移动距离计算为余弦相似度矩阵和之间的平均绝对差异。移动距离衡量了自学习以来,类别嵌入相对于所有类别特征中心的移动情况。
(a)观测分类器范数
(b)探测分类器范数
(c)观测分类器移动距离
(d)探测分类器移动距离
▲图6在SEQ过程中观察到的线性分类器与线性探测分类器的比较
如果分类器没有遗忘某个类别,那么它的类别嵌入到所有类别特征中心的距离应该保持恒定。换句话说,如果分类器没有遗忘如何使用LLMs提取的特征来分类该类别,则其移动距离将为零。
图6c和6d显示,观察分类器的类别嵌入相对于探测分类器发生了显著变化。这表明,遗忘现象的发生是因为旧的类别嵌入被推离了其初始和最优位置。
提出新方法SEQ*
最后,我们根据实验发现设计了SEQ,提出了以下策略来缩小SEQ中探测和观察性能之间的差距:(S1)Warm-up后冻结LLMs;(S2)在学习新任务时冻结旧分类器;(S3)只有在CIL场景中没有旧数据可用的情况下才使用余弦线性分类器。否则,请使用线性分类器;(S4,可选)预先分配未来的分类器。
我们将使用上述策略的方法称为SEQ,如图7所示。实验结果如图8所示。具体实验情况详见论文:
(S1)Warm-up后冻结LLMs
(S2)在学习新任务时冻结旧分类器
(S3)使用正确的分类器
(S4)预先分配未来的分类器
▲图7对SEQ*的描述
▲图8在句子级分类任务上SOTA方法和SEQ*的比较
更多阅读
#投稿通道#
让你的文字被更多人看到
如何才能让更多的优质内容以更短路径到达读者群体,缩短读者寻找优质内容的成本呢?答案就是:你不认识的人。
总有一些你不认识的人,知道你想知道的东西。PaperWeekly或许可以成为一座桥梁,促使不同背景、不同方向的学者和学术灵感相互碰撞,迸发出更多的可能性。