Tokens-to-token ViT: 对 token 做编码的纯 transformer ViT,T2T 算引入了 CNN 了吗?¶
[GaintPandaCV 导语] T2T-ViT 是纯 transformer 的形式,先对原始数据做了 token 编码后,再堆叠 Deep-narrow 网络结构的 transformer 模块,实际上 T2T 也引入了 CNN。
引言 ¶
一句话概括:也是纯 transformer 的形式,先对原始数据做了 token 编码后,再堆叠 Deep-narrow 网络结构的 transformer 模块。对 token 编码笔者认为本质上是做了局部特征提取也就是 CNN 擅长做的事情。
原论文作者认为 ViT 效果不及 CNN 的原因:
1、直接将图像分 patch 后生成 token 的方式没法建模局部结构特征 (local structure),比如相邻位置的线,边缘;
2、在限定计算量和限定训练数据数量的条件下,ViT 冗余的注意力骨架网络设计导致提取不到丰富的特征。
所以针对这俩点就提出两个解决方法:
1、找一种高效生成 token 的方法,即 Tokens-to-Token (T2T)
2、设计一个新的纯 transformer 的网络,即 deep-narrow,并对比了目前的流行的 CNN 网络。当然对比完后是作者提出的 Deep-narrow 效果最好。原文的对比实验值得去借鉴 (抄)。
1). 密稠连接,Dense Connection,类比 ResNet 和 DenseNet
2).Deep-narrow 对比 shallow-Wide,类比 Wide-ResNet
3). 通道注意力,类比 SE-ResNet
4). 在多头注意力层加入更多头,类比 ResNeXt
5).Ghost 操作,即减少 conv 的输出通道后再通过 DWConv 和 skip connect 将这俩 concat 起来,类比 GhostNet
实验的结果:给出来了炼丹配方了,这一点还是很良心的,根据现有的 CNN 的模型架构特征改造纯 transformer
Deep-narrow 能提高 VIT 的特征丰富性,模型大小和 MACs 降低,整体效果也提升了;通道注意力对 ViT 也有提升,但 Deep-narrow 结构更加高效;密稠连接会影响性能;
笔者认为最重要的 token 的生成,即可 Tokens-to-token 模块。
直接看图来分析分析,是怎么做 T2T 的,看上面 Firgure 4 橘黄色部分。
步骤 1:有重叠地取图像的区域,实际上这个区域就是做卷积的窗口,这个窗口大小是 7×7,stride 为 4,padding 为 2,然后调用 nn.Unfold 函数将 [7,7] 摊平成[49](也就是把一张饼变成一长条),其实也就是 img2col,这一步命名为 "soft split";
步骤 2:对摊平的长条做变换,这里使用了 transformer,可以用 performer 来降低 transformer 的计算复杂度,这一步命名为 "re-structurization/reconstruction";
步骤 3:将步骤 2 出来的结果 (B,H×W,C)reshape 成一个 4 维度(B,C,H,W) 矩阵;
步骤 4:跟步骤 1 一样,取一个窗口的数值,即 nn.Unfold,这次窗口是 3×3,stride 为 2,padding 为 1;
步骤 5:跟步骤 2 一样,对取到的长条做变换,即可 transformer 或者 performer;
步骤 6:跟步骤 3 一样,reshape 成一个 4 维度矩阵;
步骤 7:跟步骤 4 一样,参数也一样,取出长条;
步骤 8:将步骤 7 出来的长条做一次全连接生成固定的 token 数量。
整个 Tokens-to-token 就完成了。
代码及分析 ¶
看看代码:
class T2T_module(nn.Module):
"""
Tokens-to-Token encoding module
"""
def __init__(self, img_size=224, tokens_type='performer', in_chans=3, embed_dim=768, token_dim=64):
super().__init__()
if tokens_type == 'transformer':
print('adopt transformer encoder for tokens-to-token')
self.soft_split0 = nn.Unfold(kernel_size=(7, 7), stride=(4, 4), padding=(2, 2))
self.soft_split1 = nn.Unfold(kernel_size=(3, 3), stride=(2, 2), padding=(1, 1))
self.soft_split2 = nn.Unfold(kernel_size=(3, 3), stride=(2, 2), padding=(1, 1))
self.attention1 = Token_transformer(dim=in_chans * 7 * 7, in_dim=token_dim, num_heads=1, mlp_ratio=1.0)
self.attention2 = Token_transformer(dim=token_dim * 3 * 3, in_dim=token_dim, num_heads=1, mlp_ratio=1.0)
self.project = nn.Linear(token_dim * 3 * 3, embed_dim)
elif tokens_type == 'performer':
print('adopt performer encoder for tokens-to-token')
self.soft_split0 = nn.Unfold(kernel_size=(7, 7), stride=(4, 4), padding=(2, 2))
self.soft_split1 = nn.Unfold(kernel_size=(3, 3), stride=(2, 2), padding=(1, 1))
self.soft_split2 = nn.Unfold(kernel_size=(3, 3), stride=(2, 2), padding=(1, 1))
#self.attention1 = Token_performer(dim=token_dim, in_dim=in_chans*7*7, kernel_ratio=0.5)
#self.attention2 = Token_performer(dim=token_dim, in_dim=token_dim*3*3, kernel_ratio=0.5)
self.attention1 = Token_performer(dim=in_chans*7*7, in_dim=token_dim, kernel_ratio=0.5)
self.attention2 = Token_performer(dim=token_dim*3*3, in_dim=token_dim, kernel_ratio=0.5)
self.project = nn.Linear(token_dim * 3 * 3, embed_dim)
elif tokens_type == 'convolution': # just for comparison with conolution, not our model
# for this tokens type, you need change forward as three convolution operation
print('adopt convolution layers for tokens-to-token')
self.soft_split0 = nn.Conv2d(3, token_dim, kernel_size=(7, 7), stride=(4, 4), padding=(2, 2)) # the 1st convolution
self.soft_split1 = nn.Conv2d(token_dim, token_dim, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1)) # the 2nd convolution
self.project = nn.Conv2d(token_dim, embed_dim, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1)) # the 3rd convolution
self.num_patches = (img_size // (4 * 2 * 2)) * (img_size // (4 * 2 * 2)) # there are 3 sfot split, stride are 4,2,2 seperately
def forward(self, x):
# step0: soft split
x = self.soft_split0(x).transpose(1, 2)
# iteration1: re-structurization/reconstruction
x = self.attention1(x)
B, new_HW, C = x.shape
x = x.transpose(1,2).reshape(B, C, int(np.sqrt(new_HW)), int(np.sqrt(new_HW)))
# iteration1: soft split
x = self.soft_split1(x).transpose(1, 2)
# iteration2: re-structurization/reconstruction
x = self.attention2(x)
B, new_HW, C = x.shape
x = x.transpose(1, 2).reshape(B, C, int(np.sqrt(new_HW)), int(np.sqrt(new_HW)))
# iteration2: soft split
x = self.soft_split2(x).transpose(1, 2)
# final tokens
x = self.project(x)
return x
接下来看怎么对生成的 token 做 transformer,看上面 Firgure 4 浅灰色部分,也就是堆叠 transformer layer,最后加一个 MLP 做分类。transformer layer 就是众所周知的了。然后就是怎么做堆叠呢?Deep-narrow 的方式,也就是层数变多,维度变小,“高高瘦瘦”。这部分代码也众所周知了,就不贴代码了。而且个人觉得,虽然作者对 Deep-narrow 的对比实验非常丰富,但我个人主观认为,网络部分是为了结合 T2T,你用其他网络堆叠也是可以的,是一个调参过程。
所以,T2T-ViT 就打完收工了。
这里我有个疑问,所以 T2T 这一部分跟 CNN 有什么区别呢?看看 Figure 3。
我们知道 CNN = unfold + matmul + fold。那么 T2T 模块第一步做了 unfold,然后对取出来的窗口做了 transformer 的非线性变化,这一步我们是不是可以理解为对窗口里面的像素点做了 matmul 呢?这里的 matmul 可能更像是做 attention。然后 reshape 回去相当于做了 fold 操作。笔者认为,T2T 模块,本质上就是做了局部特征提取,也就 CNN 擅长做的事情。
个人主观评价 ¶
T2T 是一篇好文,应该是第一篇提出要对 token 进行处理的 ViT 工作,本意是为了提取更加高效的 token,这样可以减少 token 的数量,那么堆叠 transformer 模块也能降低参数量和计算量。但本质上还是隐式引入了卷积,即有 unfold + matmul + fold = CNN。对比与后来者 ViTAE,T2T 的解决方法其实更加简洁。
本文总阅读量87次