跳转至

新进展!Larimar - 让大型语言模型像人一样记忆与遗忘

1. 论文摘要

更新大型语言模型(LLM)中的知识是当前研究的一个重要挑战。本文介绍了 Larimar——一种受大脑启发的新架构,它通过分布式情节记忆来增强 LLM。Larimar 的记忆系统能够在不需要重新训练或微调的情况下,动态地进行一次性知识更新。在多个事实编辑基准测试中,Larimar 展示了与最有竞争力的基线相当的精度,即使在连续编辑的挑战性环境中也是如此。它在速度上也超过了基线,根据不同的 LLM,可以实现 4 到 10 倍的加速。此外,由于其架构的简单性、LLM 不可知论和通用性,Larimar 也展示出了灵活性。我们还提供了基于 Larimar 的一次性记忆更新机制,包括选择性事实遗忘和输入上下文长度的泛化机制,并证明了它们的有效性。

论文标题:Larimar: Large Language Models with Episodic Memory Control

论文链接https://arxiv.org/abs/2403.11901

图 1: Larimar 模型架构概览

2. 背景介绍

预训练的大型语言模型(LLM)在多种自然语言处理(NLP)任务中表现出色,并常被视为知识存储库。为了维护这些模型在部署后的事实准确性、安全性和道德性,不断更新 LLM 中的知识变得至关重要。因此,开发能够快速更新 LLM 的有效机制显得尤为关键,以确保模型能够保护隐私、消除偏见和跟进新事实的发展。模型编辑应当能够移除 LLM“记忆” 中不再需要的、错误的或过时的信息,并可选择性地替换为新的事实。同样,快速更新 LLM 的能力对于解决输入上下文长度泛化的挑战也非常有帮助,特别是在长上下文实例稀缺的数据集中学习时。一个直接的解决方案是在修正后的或新的数据集上对模型进行微调,但这种方法可能导致过度拟合和灾难性遗忘,因为知识是分布式地编码在 LLM 参数中的。已有研究提出了有效且精确的 LLM 编辑方法,包括训练辅助外部记忆模型或超网络模型与冻结的 LLM 并行工作,或者在 LLM 特征中定位原始事实,然后对相应的权重参数进行局部更新。这两种方法都存在可扩展性问题,因为它们需要重新训练或定位新状态,从而减慢了编辑速度。高内存需求用于存储多个编辑,进一步限制了在连续和批量编辑设置中扩大 LLM 更新规模的能力。这些挑战限制了在实际工业环境中更新大型语言模型的应用。此外,在处理事实编辑和选择性事实遗忘方面,即使是最先进的编辑方法也面临挑战,而在大脑中,新信息的学习和旧信息的遗忘是相互关联的。

相比之下,人类能够迅速更新知识并泛化,这与第一次接触相关实例后的快速学习相似。在大脑中,这种快速学习能力依赖于海马体及其情节记忆功能。补充学习系统(CLS)理论解释了大脑中快速(海马体)和慢速(新皮层)学习系统的耦合,前者从单个实例中学习,而后者对输入分布进行建模。此外,海马体到新皮层的记忆巩固是通过与海马体中编码的体验多次精确或虚假重放的同步激活来促进的,这表明海马体呈现出生成联想网络的形式。

受到这些见解的启发,我们提出了 Larimar——一种通过外部情节记忆控制器增强的 LLM。我们遵循 CLS 的观点:一个海马快速学习系统将样本记录为情节记忆,而一个新皮层慢速学习系统(LLM)学习输入分布的总结统计作为语义记忆。我们的目标是将情节记忆模块作为当前一组事实更新或编辑的全局存储,并将此记忆强制应用于 LLM 解码器。重要的是要学会高效、准确地更新这种记忆,而无需经过任何训练,因为新的编辑会不断到来。

表 1: 各种编辑方法的需求与能力对比分析

3. 方案提出

为了解决上述问题,我们寻求利用一种分层记忆,其灵感类似于 Kanerva 机器,其中记忆的写入和读取被解释为生成模型中的推理。具体而言,我们考虑了将记忆视为确定性的记忆模型,从而允许将 Kanerva 机器中提出的记忆和地址的贝叶斯更新重新形式化为寻找线性系统的最小二乘解。一旦更新,这种快速学习的记忆就会被用来对缓慢的 LLM 解码器施加条件。

与一组样本相关联的全局记忆,以及快速写入内存的能力,使得这种分层记忆框架对于 LLM 有效更新新知识非常有吸引力。在实现方面,记忆通过对通用数据进行端到端梯度下降与 LLM 耦合,并且不需要访问编辑。在推理期间,新数据以一次性方式写入内存,更新后的记忆然后对 LLM 解码进行条件化,以强制执行编辑输出。我们进一步在 Larimar 的一次性记忆更新机制的基础上形式化了无需训练的选择性事实遗忘和防止信息泄露的操作方式。

据我们所知,这是第一项提出并展示在线分布式写入分层条件记忆模型作为 LLM 测试时间适应新知识解决方案的工作。我们在现有基准上针对单一和连续事实编辑任务展示了 Larimar,并与基线方法进行了比较。Larimar 在这些设置中提供了准确和精确的编辑,同时速度比竞争性的模型编辑基线快 10 倍。我们进一步展示了 Larimar 在选择性事实遗忘和防止信息泄漏任务中的有效性。最后,我们提供了一种简单的基于递归搜索的解决方案,使 Larimar 的记忆能够推广到较长的输入上下文。

4. 模型架构

符号: 我们将输入和输出空间定义为XY,分别。该模型包括一个编码器e:XRC 和一个解码器d:RCY,通过一个自适应记忆相连。编码器在 C 维潜在空间中输出。记忆使用 K 行来存储长度为 N 的编码事件,初始状态为M0RK×C,并通过读取和写入权重WW0RN×K 进行更新,从而产生更新后的记忆 M。

4.1. 训练

给定记忆 M,Kanerva 机器的目标是最大化条件对数似然lnp(X|M),其中 X 是一个可交换 (顺序不变) 的事件:X={x1...xN},是输入数据的一个子集,包含 N 个样本。这个条件似然函数的变分下界正在被优化,这一过程与变分自动编码器中的操作类似。因此,该模型学会将 X 压缩到记忆 M 中,M 随后成为一个分布式关联式记忆。实际上,M 是在 Z+ξ的噪声版本上学习的,其中 Z=e(X) 表示一个事件。在本研究的其余部分,我们使用 M 表示依赖于一个事件 X 的后验记忆,而M0 表示先验记忆。读取权重矩阵 W 被视为一个随机变量,以强制模型的生成能力,对此我们使用标准高斯先验p(W)N(0IN×K) 和后验q(W)N(¯Wσ2WIN×K),其中均值¯W 是从每个事件估计的,而σW 是可学习参数。记忆读出结果被获得为Zreadout=WM。整个增强记忆的架构如图 1 所示。

在训练期间,编码器 (e)、关联记忆(M) 和解码器 (d) 三个模块都是根据一个事件 X 联合训练和优化的,使用以下损失函数:

第一项是具有记忆和 W(一个 N×K 矩阵) 的负重构损失。第二项是自动编码器在没有记忆的情况下的负重构损失。第三项是先验p(W) 和后验q(W) 之间的 KL 散度。为了在训练期间维持解码器的性能,添加了一个来自预训练数据的正则化项。

4.2. 记忆推理

一旦M0 通过反向传播进行了训练,则后验记忆 M 就通过求解一个最小化问题,我们使用minM||ZζW0M||2F 来进行一次性更新。这个最小化问题对应于求解一个线性方程组,可以通过计算矩阵伪逆有效地解决。

实现: 我们采用了一个 BERT large 编码器,结合了一个 GPT2-large 或 GPTJ-6B 解码器和一个记忆矩阵 (512x768),用于我们的训练实验,将生成的模型命名为 Larimar-1.3B 和 Larimar-6B, 分别。我们的训练数据包含 760 万个示例,由将 WikiText 文本分割成小块(64 个标记) 构建而成。在测试中,Larimar-1.3B 模型在 1000 个随机 WikiText 样本上达到了 14.6 的困惑度,而 Larimar-6B 模型达到了 15.9,表明添加记忆几乎不会影响性能。我们使用 Adam 优化器、学习率 5e-6 和批量大小 32,训练 Larimar-6B 模型 10 个周期。对于 Larimar-6B 的训练,我们使用了一个设置,包括单节点上的 8 个 NVIDIA A100-80GB GPU,利用 bfloat16 精度和 PyTorch Lightning,结合 DeepSpeed ZeRO Stage 2 进行高效的分布式训练。

5. 记忆操作

写入、读取和生成操作 作用于 Z 编码的三个基本记忆操作 “写入、读取和生成” 被转化为 (Pham et al., 2021) 中所述。

连续写入和遗忘 给定一组初始编码Z0 和写入权重W0,我们初始化记忆矩阵和键协方差矩阵:

为了连续更新记忆Mi1,要么添加一组新的编码Zi,要么遗忘之前写入的一组编码 Z_i,我们联合更新记忆矩阵和键协方差矩阵,对于i=12...:

当写入新编码到记忆时,我们使用αi=1。当遗忘之前用αiwrite=1 写入记忆的任何iwrite<i 的编码时,我们使用αi=1。等式 (4) 以这样一种方式连续更新记忆,即它保持对于不断增长的序列数据的最小二乘解。假设 M_{i-1}是对于编码 Z_{0:i-1}的最小二乘解,即

那么等式 (4) 中的αi=1 确保 M_i 同样是对于Z0:i 的最小二乘解。在情况αi=1Zi=Ziforget 对于某个iforget<i 时,等式 (4) 确保 M_i 是在从数据中移除Ziforget 后的最小二乘解,即

权重可以根据当前记忆计算,Wi=ZiMi1,或根据固定的参考记忆计算,Wi=Zi(M(ref))M(ref) 在所有连续更新过程中保持不变 (即与 i 无关),仅在推理期间使用,并且可以(可选) 使用推理期间遇到的数据构建。如果我们希望从记忆中删除一个之前写入的特定编码,则 M^{(ref)}的固定性质允许在之后的序列iforget>iwrite 的某一点重新计算原始写入键Wiwrite,以便定位记忆中的信息并将其删除。

6. 范围检测器

我们进一步引入了一个可选的范围检测机制,用于判断传入的查询是否与记忆中记录的事实相近似,这一机制在理念上与 SERAC 相似。如果查询落在预定范围内,相应的记忆内容会被读取并传递给解码器,以便进行基于记忆的条件性解码;如果不在范围内,则进行常规的无条件解码。我们考虑了以下两种情况:

基于外部编码的范围检测器(ESD):利用在 11 亿个句子对上训练得到的外部句子编码器(MiniLM)来估计样本嵌入,其输出空间的维度为 384。ESD 将编码后的事实以向量形式存储在其范围存储中。在测试阶段,对于给定的编码输入句子,通过计算与 1 - 最近邻余弦相似度来得出检测分数。对于包含多个句子的输入,系统会先将其拆分为单独的句子,分别处理每个句子,并采用最高的相似度分数。在 EasyEdit 数据集的 3800 个正负样本上的测试中,ESD 模型实现了 2.9% 的等错误率和 0.974 的 F1 分数。

基于内部编码的范围检测器(ISD):使用 Larimar 编码器 e 来嵌入 CounterFact 样本。随后,利用这些编码来训练一个二元范围分类器,正样本源自对原始事实的重新表述,而负样本则对应于接近的事实。

7. 结果展示

表 2: 在 CounterFact 数据集上,Larimar 与其他基线方法在单一事实编辑任务的性能比较

表 3: 在 ZsRE 数据集上进行连续事实编辑时,Larimar 保持对旧编辑记忆的能力

图 2: 在 CounterFact 数据集上进行批量事实编辑的准确度分析

图 3: 在 ZsRE 数据集上完成 3000 次编辑后,Larimar 在处理未见过的重述样本时的平均 F1 分数,证明其在两个数据集上(分别含有 1000 个和 511 个独立事实,每个事实有 10 个或约 20 个重述版本)的优越泛化性能

表 4: 在 CounterFact 和 ZsRE 数据集上,经过 N 次事实写入后移除特定事实,Larimar 准确召回事实的能力

表 5: 在 CounterFact 样本上,针对输入重述的攻击成功率分析(预算限制为 20)

8. 总结

在本项研究中,我们探索了结合大型语言模型(LLM)与动态可更新的分布式情节记忆,作为在线知识更新的新途径。我们的框架采用一次性记忆更新机制和基于记忆的解码条件,展现出在编辑性能上超越传统方法的准确性、精确性和稳定性,速度也显著更高,这些优势不仅体现在单个事实的编辑上,也同样适用于连续序列编辑的复杂场景。此外,我们的记忆更新机制还能够实现快速且选择性的事实遗忘和高效的信息删除。我们还提出了一种简便的处理长输入上下文的方法,通过在 Larimar 的记忆空间中进行递归读取,证明了其在回忆长输入上下文中的事实方面,相较于在更大训练上下文窗口中训练的最新 LLM,有着更好的表现。


本文总阅读量28