基于双提示的排练式持续学习方法:DUPT

学术背景

在机器学习和神经网络领域,持续学习(Continual Learning)是一个重要的研究方向。持续学习的目标是让模型能够在一系列任务中不断学习新知识,同时避免遗忘已经学到的旧知识。然而,现有的持续学习方法面临一个主要挑战:灾难性遗忘(Catastrophic Forgetting)。灾难性遗忘指的是模型在学习新任务时,会迅速遗忘之前学到的知识,导致旧任务的性能大幅下降。这一问题在现实应用中尤为突出,因为许多任务需要模型在不断变化的环境中持续学习和适应。

为了解决这一问题,研究者们提出了多种方法,其中基于回放的方法(Rehearsal-based Methods)是一种常见的解决方案。这类方法通过存储旧任务的代表性样本,并在学习新任务时回放这些样本来巩固旧知识。然而,现有的回放方法存在两个主要问题:1)模型在学习新任务时,由于样本数量有限,其泛化能力较弱;2)知识蒸馏(Knowledge Distillation)虽然可以传递旧知识,但过强的约束可能会限制模型学习新知识的能力。

为了缓解这些问题,南京信息工程大学、南京林业大学、东南大学和南京邮电大学的研究团队提出了一种基于双提示的持续学习方法,称为DUPT。该方法通过引入输入感知提示(Input-aware Prompt)和代理特征提示(Proxy Feature Prompt),从输入和特征两个层面增强模型的泛化能力和知识传递效率。

论文来源

这篇论文由Shengqin JiangDaolong ZhangFengna ChengXiaobo LuQingshan Liu共同撰写。作者分别来自南京信息工程大学计算机学院、南京林业大学机电工程学院、东南大学自动化学院和南京邮电大学计算机学院。论文于2025年发表在Neural Networks期刊上,题目为《DUPT: Rehearsal-based Continual Learning with Dual Prompts》。

研究流程

1. 输入感知提示(Input-aware Prompt)

在持续学习的过程中,新任务的样本数量通常较少,这限制了模型的泛化能力。为了解决这一问题,DUPT引入了输入感知提示,通过动态扩展输入分布,帮助模型更好地捕捉新任务的样本特征。

具体来说,输入感知提示通过以下步骤生成: 1. 输入数据预处理:将输入图像下采样至16×16分辨率,以减少计算复杂度。 2. 注意力机制:将下采样后的图像输入一个冻结的注意力模块,生成注意力向量。 3. 权重生成:将注意力向量通过全连接层,生成与提示池中提示数量相等的权重向量。 4. 提示生成:将权重向量与提示池中的提示进行加权求和,生成最终的输入感知提示。

输入感知提示的优势在于,它能够利用有限的提示生成多样化的输入分布,从而增强模型的泛化能力。

2. 代理特征提示(Proxy Feature Prompt)

在持续学习中,旧知识的传递通常通过知识蒸馏实现。然而,直接对齐新旧模型的特征可能会限制模型学习新知识的能力。为了解决这一问题,DUPT引入了代理特征提示,通过构建可学习的中间特征表示,缓解特征冲突。

具体来说,代理特征提示的生成过程如下: 1. 提示池初始化:初始化一个包含固定数量提示的提示池。 2. 特征提取:将提示池中的提示分别输入卷积层和全连接层,生成可学习的提示。 3. 知识蒸馏:通过优化目标函数,约束当前模型特征与代理特征提示之间的差异,同时保持代理特征提示与旧模型特征的一致性。

代理特征提示的优势在于,它能够避免新旧模型特征之间的直接对齐,从而在保持旧知识的同时,增强模型学习新知识的能力。

3. 优化目标

DUPT的优化目标包括以下几个部分: 1. 交叉熵损失:用于优化当前任务的数据。 2. 回放交叉熵损失:用于优化回放缓冲区中的旧任务数据。 3. 回放对数蒸馏损失:用于约束当前模型与旧模型在回放数据上的输出差异。 4. 特征蒸馏损失:用于约束当前模型特征与代理特征提示之间的差异。

通过联合优化这些目标,DUPT能够在持续学习过程中同时增强模型的稳定性和可塑性。

主要结果

DUPT在多个数据集上进行了实验,包括CIFAR10、CIFAR100和TinyImageNet。实验结果表明,DUPT在持续学习任务中表现优异,尤其是在缓冲区较小的情况下,DUPT的性能显著优于现有方法。

  1. CIFAR10数据集:在缓冲区大小为200的情况下,DUPT将DER++的平均准确率提高了4.92%。
  2. CIFAR100数据集:在缓冲区大小为500的情况下,DUPT将DER++的平均准确率提高了3.41%。
  3. TinyImageNet数据集:在缓冲区大小为4000的情况下,DUPT将DER-BFP的平均准确率提高了0.82%。

此外,DUPT还展示了与现有方法的兼容性。当与最新的DER-BFP方法结合时,DUPT在CIFAR10和CIFAR100数据集上分别实现了1.30%和1.34%的性能提升。

结论

DUPT通过引入输入感知提示和代理特征提示,从输入和特征两个层面增强了持续学习模型的泛化能力和知识传递效率。实验结果表明,DUPT在多个数据集上均表现出色,尤其是在缓冲区较小的情况下,DUPT的性能显著优于现有方法。此外,DUPT的兼容性使其能够与现有的持续学习方法无缝集成,进一步提升了性能。

研究亮点

  1. 双提示机制:DUPT通过输入感知提示和代理特征提示,从输入和特征两个层面增强了模型的泛化能力和知识传递效率。
  2. 显著性能提升:在缓冲区较小的情况下,DUPT在多个数据集上均实现了显著的性能提升。
  3. 兼容性强:DUPT能够与现有的持续学习方法无缝集成,进一步提升了性能。

未来展望

尽管DUPT在持续学习任务中表现优异,但仍有一些问题需要进一步探索。首先,在缓冲区较小的情况下,DUPT的性能仍然落后于缓冲区较大的情况。如何更有效地表示旧知识仍然是一个开放性问题。其次,DUPT依赖于从头训练的模型,这些模型在小数据集上容易过拟合。未来的研究可以探索如何利用预训练模型来缓解这一问题。

DUPT为持续学习提供了一种有效的解决方案,具有重要的科学价值和应用前景。