跳转至

Token Merging: Your ViT But Faster

1. 论文信息

标题:Token Merging: Your ViT But Faster

作者:Daniel Bolya, Cheng-Yang Fu, Xiaoliang Dai, Peizhao Zhang, Christoph Feichtenhofer, Judy Hoffman

原文链接:https://arxiv.org/pdf/2210.09461.pdf

代码链接:https://github.com/facebookresearch/ToMe

2. 介绍

自从 Transformer 的思想从 NLP 引入 VISION——Vision Transformers(ViTs),计算机视觉开始快速发展。然而,与 NLP 不同的是,视觉信息冗余性相对较高,所以出现了很多特定的方法来试图提高计算的效率,减少不必要的注意力计算开销。添加视觉特定的感应模块使 Vision Transformers 能在更少的计算复杂度下获得更好的性能。然而,尽管以 Swin-Transformer 在速度上具有较为明显的优势,但普通而原始的的 ViT 仍然有许多较为优越的性质:

  • 它们由简单的矩阵乘法组成,这使得它们比原始的浮点计算的速度要快;
  • 它们支持强大的自我监督训练前技术,如 MAE ,可以在快速训练的同时提供最先进的结果。
  • 由于缺乏对数据的假设,它们可以在许多模式下不引入过多的变换情况下就可以成功应用。
  • 而且它们在海量数据下也能很好地进行泛化,预训练之后的模型在下游任务中,如 ImageNet-1K 图像分类和检测分割等各种经典的感知任务,具有较为卓越的表现。

但是在一些 edge device 上,受限于计算资源,在这上面计算 Transformer 则不太方便,因为 Transformer 模型较大,难以部署。自然而然的想法就是利用剪纸来对 Vision Transformer 进行加速。但是它需要额外的训练过程,如果进行预训练那么这种计算开销就非常大。与此同时,token 剪枝限制了模型的泛化性能。如果 token 数量发生变化,批处理 (Batch Inference) 就比较困难了。为了解决这个问题,大多数 token 剪枝的工作借助了 Mask,对冗余的 token 进行遮挡。但是这样又引入了额外的计算开销,这就导致了其实希望做到的模型复杂度降低没有完成。另一方面,token 剪枝也会带来一定的信息损失。

经过前面的陈述,其实本文想解决的问题就非常显而易见了:如果在保持 ViT 结构的时候,使得计算复杂度降低,从而增强 ViT 的实用性?本文希望做一个无需训练,而且可以较好完成性能和效率的 trade-off 的 token 的融合方法。对于大模型将会非常友好,无需重新训练是一个非常大的优势。在训练过程中使用 ToMe,可以观察到训练速度增长,而总训练时间缩短了一半。

3. 方法

Token Merging 的基本思路是在一个 ViT 模型中间插入一些减少 token 的模块, 通过这些模块来减少 self-attention 的计算量,从而减少计算的复杂度。基本作法是在每一个层之后减少 m 个 token, 那么一个有 n 层的 Transformer 模型从头到尾减少的 token 数量就是m×n。减少的 token 数量就越多,计算复杂度降低的就越多,随之而来的是更差的精度。而且值得注意的是, 无论一张输入图片有多少个 tokens,和 token 剪枝算法不同,减少 token 的量都是固定的而不是随着 token 数量动态变化的。采用这种模式才能让 ViT 在实际部署中能完成 Batch Inference,可以在真实应用场景下完成部署。

而如何完成 tokens merging 呢?本文定义了 Token Similarity 来完成这一步骤。根据之前的研究,现有的 ViT 实际上都 overparameterized 的,信息存在很强的冗余。而 Transformer 本身就通过 QKV 的 self-attention 解决了这个问题。具体来说,Key 已经总结了每个令牌中包含的信息,用于点积的相似度。因此,我们使用每个标记的键之间的点积相似度度量 (例如,余弦相似度) 来确定哪些包含相似的信息(是重合的。如上图,KQV 都具有不错的表现,所以最后我们选择利用 K 的 cosine 相似度来计算 distance。然后用 mean 来作为聚合的手段,而不是拼接在一起,可以提高效率。

之后就是完成 Bipartite Soft Matching:

  1. 把该模块输入的所有 tokens 分为相同大小的 2 个集合 。
  2. 把每个 token 和与它最相似的 token 之间画一条边。
  3. 只留下最相似的 r 条边, 其余删掉。
  4. 融合仍然相连的条边 ,对 feature 取均值。
  5. 把这两个集合拼在一起, 得到 ToMe 模块的融合结果。

之后对注意力权重进行调整。在 由于 ToMe 融合了多个 token, Attention 矩阵的维度应该是增加的, 融合了 token 之后, 有的 Key 应该占的 Attention 比重大一些, 因为它融合了多个 token 的信息。

通过上式将行向量 直接加在 Attention 矩阵上面,相当于是人为增加了有些 Key 的 attention weight,而这些 key 恰好是发生了融合的 Key。

4. 实验

按照文章的顺序,首先来看消融实验

可以发现 pool 的方式来讲,根据 token 融合的数量来改变的策略上,加权平均改变 attention 还是具有较为明显的精度上的优势(虽然有效率上的劣势)。

可以看到 MAE 的预训练方式上,也有较为有趣的提升,

从这个实验结果也可以思考:为什么论文选择了匹配而不是聚类来作为减少 token 数量的方法。因为聚类算法没有限制每个类的具体数量,因此无法做到在每层精确地减少 tokens 的数量。而匹配算法不同,匹配算法可以做到精确地匹配定量的 tokens,并把它们融合在一起。所以 matching 的方法相较于聚类(kmeans)有较为明显的优势。

结果表明,无论模型的尺寸和类型,ToMe 都能够带来约 2 倍的吞吐量加速。即使减少 96-98% 的 tokens,最大下模型精度几乎保持一致。在 2 倍计算复杂度降低的设定下,AugReg (对比的另一种增速方法)得到的 ViT-B,ViT-S 和 ViT-Ti 都有大约 4-5% 的精度下降。由此可见 ToMe 的优势还是非常明显的。也可以发现 ViT 的冗余度有多高。

5. 结论

在这项工作中,我们引入了 Token Merging(ToMe),通过逐步合并 Token 来增加 ViT 模型的计算焦虑。ToMe 自然利用了输入视觉信息中的冗余,允许将其用于任何具有冗余的模式。ToMe 在跨领域的大型模型上工作得很好,并且减少了训练时间和内存使用,这意味着它可以成为训练大型模型的核心组件。ToMe 指出了一个非常有前景的方向,希望之后大家可以基于 ToMe,创造出更好、更高效的 Transformer。


本文总阅读量833