当前位置: 网站首页 技术文章 正文

Llama 3+Mamba 强强联合!蒸馏到线性RNN,推理速度翻倍再加成

来源:互联网 发布时间:2024-09-18 23:20:06

将Transformers融入Mamba框架,能够实现推理速度最高达1.6倍的增强。

Llama3+Mamba强强联合!蒸馏到线性RNN,推理速度翻倍再加成

此成果出自Together AI团队之手,他们成功地将Transformers和Mamba模型相结合,并设计了专门针对混合模型的加速推理算法。

值得一提的是,Mamba架构的设计者、FlashAttention的创造者Tri Dao也参与了该项目。

Together AI的创始人兼首席执行官认为,Transformers与Mamba的融合代表了大规模模型未来的发展方向。

为了将Transformers提炼至Mamba,首先需要从Transformers到线性RNN的初始化过程。

研究者们发现,Transformers的注意力机制与RNN的计算流程之间有着潜在的关联。

通过将Transformers的注意力机制线性化,可以建立起两者之间的桥梁。

借助这种关联,预训练的Transformers模型参数可以映射到Mamba模型中。

在参数初始化完成后,研究者们采用了三阶段的提炼流程,以进一步增强Mamba模型的能力,使其更有效地吸收Transformers的知识。

首阶段是以伪标签为基础的提炼——使用预训练的Transformers教师模型生成无标签数据上的伪标签,然后让Mamba学生模型依据这些伪标签进行训练。

这一阶段的损失函数结合了KL散度和交叉熵损失,旨在模仿教师模型的输出分布并匹配伪标签。

第二阶段则是在指令数据集上进行监督微调,使用带有标签的指令数据集进行训练。

最后阶段则是利用人类反馈数据,通过基于奖励模型的优化方法来进行调整。

在配备有8块80GB A100 GPU的环境下,每个混合模型的完整提炼过程仅需不到五天即可完成。

Llama3+Mamba强强联合!蒸馏到线性RNN,推理速度翻倍再加成

通过上述提炼流程,研究人员获得了Transformers-Mamba混合模型,并提出了“推测解码”算法以加快推理过程。

推测解码的核心理念是运用一个轻量级的草稿模型来预测多个标记,随后用验证模型来确认这些预测。

这种方法极大地提高了解码的并行性,从而加速了生成过程。

草稿模型通常是一个小型的Transformers,它可以根据现有的上下文预测出接下来的K个标记。

对于预测出的K个标记,Transformers层可以并行处理这些标记,计算其隐藏状态;而Mamba层则按顺序逐个处理每个标记,首先计算当前标记的隐藏状态,并与先前的状态进行对比。

如果当前标记正确,则将其加入已接受的序列中,并更新最新的隐藏状态(但不保存中间状态);若当前标记错误,则停止后续标记的处理,并将最新的隐藏状态回退至上一个已接受的标记。

如果序列中的所有K个标记都被接受,则将它们加入输出序列,并继续预测下一轮标记。

如果出现标记被拒绝的情况,则从首个被拒标记处截断预测序列,并从该点重新开始预测。

实验结果显示,在单轮对话(AlpacaEval)和多轮对话(MT-Bench)任务中,混合模型的表现与Llama-3持平甚至更佳。

此外,不同混合比例的模型测试显示,1:1混合比例的模型表现最佳。

在无需示例的通用NLP任务评估中,混合模型的平均得分优于同等规模的RNN模型。

在少量示例的OpenLLM排行榜上,混合模型与顶级开源RNN模型表现相当,并在GSM8K和CRUX任务上超越了相应的Instruct模型。

除了模型性能外,研究人员还测试了推测解码算法带来的速度增益。

首先在纯Mamba模型上进行测试,结果表明在2.8B和7B模型上,与传统解码方式相比,推理速度提升了1.7到2.6倍。

Llama3+Mamba强强联合!蒸馏到线性RNN,推理速度翻倍再加成

进一步地,在提炼的Zephyr和Llama混合模型上测试,结果表明Zephyr混合模型的推理速度提高了超过1.8倍,而Llama混合模型也有约1.6倍的速度提升。

相关教程