在线持续学习是深度学习领域的一个重要研究方向,旨在解决传统深度学习方法在处理动态数据时的局限性。传统的深度学习算法通常在静态的数据集上进行离线训练,将训练后的模型部署到实际应用中,并在部署后不再更新模型。然而,现实世界中的数据往往与训练数据集不完全相同,且是动态变化的。新的数据不断产生,旧的数据可能过时或失效。传统的静态模型往往无法适应这种变化。
举个例子来说,当我们将源域上训练的深度神经网络部署到测试环境,即目标域时,目标域上的模型性能会因为域偏移而恶化。在自动驾驶中,一个训练完成的模型可能由于天气、运行区域、传感器等的不同而在测试时表现出显著的性能下降。实际应用中,测试集与训练集不完全匹配是十分普遍的。并且由于难以预知环境将会发生怎样的变化,需要在与环境不断交互的过程当中,学习到在新的场景下更准确的数据表示,从而保持任务的性能。这就是一个在线持续域适应的问题。
片上持续在线学习是SNN的一个具有潜力的应用场景。
基于小批量/单样本的在线持续域适应面临一系列的挑战。首先,单样本更新导致模型的归一化统计参数难以准确估计。模型的归一化统计参数是指,由一个批量中所有样本的特征计算出的均值与方差。批归一化统计(BatchNorm)是目前深度学习中非常普遍的一种做法,可以缓解梯度爆炸/消失、减少过拟合风险、增加模型收敛速度、减少模型对于训练参数的依赖等。基于少量/单个样本计算出的统计信息是有偏的,而基于有偏的统计参数计算的归一化值会不准确,导致模型性能下降。同时,基于单个样本产生的损失更新模型会导致模型更新不稳定。由于样本间质量参差不齐,不同样本的特征很可能有较大差异,导致产生的模型更新方向不稳定。
本章节中,介绍一些小批量/单样本更新的方法。
图1重新校准BN层统计参数,并仅更新仿射参数[1]
式(1):BN层特征计算公式;式(2):BN层统计参数均值与方差的计算公式
实际上,仅是通过BN层统计参数的重计算便能大幅提升自适应的性能。图2展示了不同情况下图像的特征分布。方法(a)Source是采取源域统计量与源模型仿射参数,方法(b)BN是采取目标域统计量与源模型仿射参数。结合式(1)来看,方法(b)作归一化时,其μ与σ来自目标域数据,而γ与β来自源域。方法(c)Tent是采取目标域统计量,并基于最小化熵更新仿射参数。即在方法(b)基础上,基于目标域数据更新γ与β。方法(d)Oracle采用目标域数据监督训练模型,是我们所期望的特征分布。
四张子图中,最后侧的黄色部分均是没有噪声的图像在源域模型上的特征分布。如图2(a)所示,如果直接将目标域数据应用到源域模型上,模型提取的特征分布会与未加噪的数据有较大的不同,从而造成性能的下降。如图2(b)所示,方法(b)调整了带噪数据特征分布的位置和宽窄。可以发现仅是调整BN层参数后,带噪数据的特征分布就已大幅接近目标域数据监督训练的结果。方法(a)的错误率可能高达80%多,而调整统计参数为目标域数据统计参数后,可以把错误率降到20%左右。方法(c)在(b)的基础上,对数据分布的形状也做出了一定的调整,可以令错误率再下降大概1-3个百分点。
图2带有高斯噪声的CIFAR100-C图像的特征分布[1]
实际上目前以基于BN的模型为预训练网络进行适应的方式,都极大依赖于BN层参数的调整。但是这种方式的缺点也显而易见。因为要对统计参数有一个校准的估计值,要求一个batch的中有较多的样本。如表1所示,当batchsize为32的时候,Tent准确率为85.5。但是batchsize降低到16、8的时候,准确率会降低到35.1、16.7。同时可以观察到,在batchsize为16、8的时候,发生了灾难性遗忘。由于较小的batchsize估计的统计参数不准,仅使用未加任何修改的Tent的话,是难以在小样本、单样本上进行适应的。这实际上不符合现实生活中数据流式输入的场景。但是由于目前的主流预训练模型都是基于BN实现的,后续很多实现小样本、单样本自适应的方法其实是基于改进的BN统计参数估计方法。
表1不同batchsize下Tent的性能[2]
如表2所示,在将原始BN模块替换为改进后的模块(MECTA)[2]后,将Tent在batchsize为16时的准确率从35提升到了71。当然MECTA中除了用到所介绍的自适应滑动更新BN层的方法外,还采取了稀疏剪枝、按需训练一类的策略,提升效果是共同作用下的结果。
但是这种方法实际还是依赖于小批量的统计数据分布,会导致当batchsize进一步减小时,当前小批量的计算的分布统计参数漂移大,导致性能下降。但是其提出的从数据中动态估计滑动平均参数以及剪枝的思想是可以借鉴的。
表2MECTA对于Tent的提升作用[2]
2.利用数据增广估计单样本的BatchNorm统计参数
另一种方法是利用数据增广估计单样本的BN参数。如图3所示,谷歌斯坦福在2022年提出AugBN[3],借助单样本的多个增强估计了单样本的BN层统计参数。它实际上是对一个样本施加多次数据增强,然后用原样本与增强后的样本一起计算BN统计参数,再与源模型参数加权平均。前面介绍的的MECTA是在batch的维度上面滑动平均,而AugBN是在源域统计数据单样本和多个增强的统计数据上面加权平均。
图3AugBN:利用数据增广估计BN层统计参数[3]
由于数据增强样本的分布难以控制,所以不是为所有的增广样本分配与原样本相同的权重,把增广样本的权重设置为1/2n,其中n为数据增强数目。实际实验时,n=2,即对单样本进行两次增强;k=5,m=5,就是说每次用五种数据增强的组合作用到x上。
由于AugBN需要跑次不同的先验值,实际使用时,,再用熵最小的top3结果进行投票。就是说,AugBN实际只解决了流式样本更新时每次只有单个样本可用的问题。但实际上它需要有多次的前向过程,增加了推理过程的计算量。但是之所以需要那么多次前向过程,是因为其本身是一种非参数化的方法,需要依据多次迭代的投票找出分类。如果结合在线适应的一些无监督loss可能可以减少前向的次数。表3展示了AugBN在各数据集上的性能,在分类方面,与直接使用源模型相比,AugBN在CIFAR-10-C上取得了17%的相对提升,并且与现有的方法相比也有不错的表现。但在ImageNet-C上的准确率仍仅有25%左右。
表3AugBN在各数据集上性能
3.利用InstanceNorm修正BatchNorm统计参数
NIPS2022年发表的NOTE提出了一种利用instancenorm(IN)修正batchnorm值的方法[4]。在介绍NOTE之前,我们来大致了解一下不同的norm方法。如图4所示,BatchNorm是针对一个channel计算当前channel的均值、方差进行标准化,Layernorm是针对单个样本的所有channel进行均值、方差的计算,InstanceNorm是针对单个样本、单个channel的特征图进行标准化,GroupNorm是针对单个样本的成组特征进行标准化。
图4不同的归一化方法图示[6]
图5非独立同分布流式样本示意图[4]
具体来说,NOTE提出的均值、方差估计公式如式(4)所示。
表4NOTE在各数据集上性能
图6不同样本分布偏移程度与batchsize下各算法的性能
从前面几种方法介绍中也可以看出,基于BN层的自适应有以下几个问题:
基于此,作者将三种策略用于SAR中以提升基于GN进行自适应的性能:
不同方法在ImageNet-C(severity5)上单样本适应的性能如表5所示,总体上来说,SAR在达到较高准确率的同时,具有较低的复杂度,并且不需要额外的数据。与方法[3]比,在ImageNet-C上达到了更高的准确率。
表5不同方法在ImageNet-C(severity5)上单样本适应的性能
本文主要介绍了目前小批量/单样本在线自适应的一些挑战和可能的解决方法。具体来说,介绍了
综合而言,在标签和输入分布难以预知的情况下,基于流式输入进行稳定、在线的学习并避免灾难性遗忘,仍然是深度学习领域的一个复杂且重要的问题。在线持续学习为我们提供了机会去构建更加灵活、智能的模型,以应对不断变化的现实世界需求,推动算法落地于实际生活中。
[1]D.Wang,E.Shelhamer,S.Liu,B.Olshausen,andT.Darrell,“Tent:FullyTest-timeAdaptationbyEntropyMinimization.”arXiv,Mar.18,2021.doi:10.48550/arXiv.2006.10726.
[3]A.Khurana,S.Paul,P.Rai,S.Biswas,andG.Aggarwal,“SITA:SingleImageTest-timeAdaptation.”arXiv,Sep.07,2022.doi:10.48550/arXiv.2112.02355.
[4]T.Gong,J.Jeong,T.Kim,Y.Kim,J.Shin,andS.-J.Lee,“NOTE:RobustContinualTest-timeAdaptationAgainstTemporalCorrelation.”arXiv,Jan.11,2023.doi:10.48550/arXiv.2208.05117.
[5]S.Niuetal.,“TowardsStableTest-TimeAdaptationinDynamicWildWorld.”arXiv,Feb.23,2023.doi:10.48550/arXiv.2302.12400.