TokenLearner

recollect:

[ViT 2020] AN IMAGE IS WORTH 16X16 WORDS: TRANSFORMERS FOR IMAGE RECOGNITION AT SCALE, Google. Established the now-standard vision transformer paradigm of cutting the image into patches and using them as tokens, mirroring the word/character splitting used for text; however, a patch and a word embedding carry very different amounts of information (pixel information is much lower-level).

[TokenLearner 2022] TokenLearner: What Can 8 Learned Tokens Do for Images and Videos? Google. Uses a much smaller number of learnable tokens that learn to mine the important information.

repo:https://github.com/google-research/scenic/tree/main/scenic/projects/token_learner

unofficial keras repo:https://github.com/ariG23498/TokenLearner

TokenLearner: What Can 8 Learned Tokens Do for Images and Videos?

  1. Motivation

    • tokens
      • previous: densely sampled patches
      • ours: a handful of adaptively learned tokens
    • learn to mine important tokens in visual data
      • find a few important visual tokens
      • enable long range pair-wise attention
    • applicable to both image & video tasks
      • strong performance
      • computationally more efficient
    • comparable results are verified on classification tasks
      • compared against the state of the art on ImageNet
      • video datasets, including Kinetics-400, Kinetics-600, Charades, and AViD
  2. Argument

    • the main challenge of ViTs
      • require too many tokens: with 16x16 patches, even a 512x512 image already corresponds to 1024 tokens
      • a transformer block's computation and memory grow quadratically with the token length (a quick numeric check follows this list)
      • this limits scaling to larger images / longer videos
    • thus we propose TokenLearner
      • a learnable module
      • takes an image as input
      • generates a small set of tokens
      • the idea is direct: find the regions-of-importance in the image, then generate tokens from those regions
      • experiments find that keeping only 8-16 learned tokens (transformer blocks previously operate on roughly 200-500 tokens) preserves or even improves accuracy while lowering FLOPs
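
A quick numeric check of the quadratic-cost point above (illustrative arithmetic only, not taken from the paper):

```python
# Illustrative arithmetic: count 16x16 patch tokens and the pairwise-attention terms.
PATCH = 16
for side in (224, 512):
    n_patches = (side // PATCH) ** 2          # 224 -> 196 tokens, 512 -> 1024 tokens
    print(side, n_patches, n_patches ** 2)    # attention pairs: ~38k and ~1M respectively
for n_learned in (8, 16):
    print(n_learned, n_learned ** 2)          # TokenLearner: only 64-256 pairs downstream
```
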
  3. Method

    • TokenLearner

      • formulation

        • a space-time tensor input $X$:$X \in R^{T \times H \times W \times C}$
        • a temporal slice $X_t$:$X_t \in R^{H \times W \times C}$
        • T is the temporal dimension (T=1 for an image); H, W, C are the usual height/width/channel dimensions
        • for every time frame t, we learn to generate a series of tokens $Z_t$ from the input frame $X_t$: $Z_t=[z_i]_{i=1}^S$
        • thus we use a tokenizer function $A_i$: $z_i=A_i(X_t)$, which adaptively selects an important combination of pixels
        • there are S such functions, and S is much smaller than HW; typically S=8
      • tokenizer function

        • implemented with a spatial attention mechanism
        • first generate a spatial weight map (size H×W×1)
        • then multiply it with $X_t$ to obtain an intermediate weighted tensor (size H×W×C)
        • finally apply global average pooling over the spatial dimensions, turning the weighted maps into a vector (size C)
        • all the resulting tokens are gathered to form the output $Z_t =[z_i]_{i=1}^S\in R^{S \times C}$
        • two implementations of the spatial attention (a minimal sketch follows this block)
          • v1.0 of the paper uses one or more conv layers (channels=S) + sigmoid
          • v1.1 uses an MLP (dense-gelu-dense)
          • (the parameter counts of these two versions differ substantially)
      • Figure: the input image in $R^{H \times W \times C}$ is sparsely mapped to $R^{S \times C}$
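
A minimal PyTorch sketch of the tokenizer function described above, assuming the v1.0 variant (conv with S output channels + sigmoid, then weighted spatial pooling); the single 3x3 conv, the channels-last layout, and all sizes are illustrative choices, not the official scenic implementation:

```python
import torch
import torch.nn as nn

class TokenLearner(nn.Module):
    """Sketch of the v1.0 tokenizer: conv (out_channels=S) + sigmoid + weighted pooling."""
    def __init__(self, in_channels: int, num_tokens: int = 8):
        super().__init__()
        self.num_tokens = num_tokens
        # one conv predicting the S spatial weight maps (the paper uses one or more conv layers)
        self.attn_conv = nn.Conv2d(in_channels, num_tokens, kernel_size=3, padding=1)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # x is a frame X_t of shape (B, H, W, C); output Z_t has shape (B, S, C)
        b, h, w, c = x.shape
        maps = torch.sigmoid(self.attn_conv(x.permute(0, 3, 1, 2)))  # (B, S, H, W) weight maps
        maps = maps.reshape(b, self.num_tokens, h * w)
        feats = x.reshape(b, h * w, c)
        # weight each pixel, then global-average-pool over space -> one C-dim vector per map
        return torch.einsum('bsp,bpc->bsc', maps, feats) / (h * w)

# e.g. an intermediate ViT feature map: 14x14 grid, 192 channels -> 8 learned tokens
z = TokenLearner(192, num_tokens=8)(torch.randn(2, 14, 14, 192))
print(z.shape)  # torch.Size([2, 8, 192])
```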

    • TokenFuser

      • after the Transformer layers, the tensor flowing through the network is still in $R^{S \times C}$

      • introduce the TokenFuser

        • fuse information across tokens
        • remap the representation back to the original resolution
      • first the fuse step: given tokens $Y\in R^{ST \times C}$, multiply by a learnable weight $M \in R^{ST \times ST}$ to obtain a tensor $\in R^{ST \times C}$; this can be understood as spatial (or spatio-temporal) mixing across tokens

      • then the remap step, for each temporal slice $Y_t \in R^{S \times C}$:

        • $X_t^{j+1} = B(Y_t, X_t^j) = B_w Y_t + X_t^j = \beta_i(X_t^j)Y_t+X_t^j$
          • $X_t^j$ is the residual input (the original H×W×C feature map fed to the TokenLearner), i.e. the branch waiting to be reweighted
          • $X_t^{j+1}$ is the module output
          • $Y_t$ is the result of the TokenFuser's fuse step, corresponding to the S×C transformer output in the figure
        • $\beta_i()$ is a dense + sigmoid applied to the original feature map, yielding an H×W×S weight tensor $B_w$
        • this is multiplied with Y to obtain an H×W×C tensor
        • which is then added to the residual (a code sketch of both TokenFuser steps follows this block)
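
A matching sketch of the TokenFuser for a single frame (T=1, so the ST×ST mixing weight reduces to S×S), again with assumed layer choices rather than the official implementation:

```python
import torch
import torch.nn as nn

class TokenFuser(nn.Module):
    """Sketch of TokenFuser (single frame): token mixing, then remap back to H x W x C."""
    def __init__(self, channels: int, num_tokens: int = 8):
        super().__init__()
        # fuse: learnable S x S weight M applied across the token axis of Y_t
        self.mix = nn.Linear(num_tokens, num_tokens, bias=False)
        # beta: dense + sigmoid on the original feature map X_t^j -> (HW, S) weights B_w
        self.beta = nn.Linear(channels, num_tokens)

    def forward(self, tokens: torch.Tensor, x: torch.Tensor) -> torch.Tensor:
        # tokens Y_t: (B, S, C) from the transformer layers; x X_t^j: (B, H, W, C) residual branch
        b, h, w, c = x.shape
        y = self.mix(tokens.transpose(1, 2)).transpose(1, 2)    # fuse across tokens, still (B, S, C)
        b_w = torch.sigmoid(self.beta(x.reshape(b, h * w, c)))  # (B, HW, S)
        remapped = torch.einsum('bps,bsc->bpc', b_w, y)         # B_w Y_t -> (B, HW, C)
        return remapped.reshape(b, h, w, c) + x                 # + X_t^j (residual)

fused = TokenFuser(192, num_tokens=8)(torch.randn(2, 8, 192), torch.randn(2, 14, 14, 192))
print(fused.shape)  # torch.Size([2, 14, 14, 192])
```
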
    • Overall architecture

      • overall computation flow (figure)

      • two model structures (with / without TokenFuser); a toy composition sketch of the with-fuser variant follows
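
To make the data flow concrete, here is a toy composition of the two sketches above for the with-TokenFuser variant: TokenLearner inserted partway through a ViT, a standard transformer layer over the S learned tokens, then TokenFuser back to the spatial grid so the pattern can repeat. The nn.TransformerEncoderLayer and all dimensions are placeholder choices, not the paper's exact configuration:

```python
import torch
import torch.nn as nn

class TokenLearnerViTBlock(nn.Module):
    """Toy composition: TokenLearner -> transformer layer over S tokens -> TokenFuser."""
    def __init__(self, channels: int = 192, num_tokens: int = 8, num_heads: int = 3):
        super().__init__()
        self.learner = TokenLearner(channels, num_tokens)      # sketch above
        self.block = nn.TransformerEncoderLayer(               # attention over S tokens only
            d_model=channels, nhead=num_heads, dim_feedforward=4 * channels,
            batch_first=True, norm_first=True)
        self.fuser = TokenFuser(channels, num_tokens)          # omit for the fuser-free variant

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # x: (B, H, W, C) feature map from earlier (unmodified) ViT layers
        z = self.learner(x)        # (B, S, C): all later attention is over S=8 tokens
        z = self.block(z)          # standard transformer layer, now quadratic only in S
        return self.fuser(z, x)    # remap to (B, H, W, C) so further blocks can repeat this

y = TokenLearnerViTBlock()(torch.randn(2, 14, 14, 192))
print(y.shape)  # torch.Size([2, 14, 14, 192])
```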

  4. Experiments

    • settings

      • to be added
    • TokenFuser ablation: improves results overall, but the gain shrinks as the model gets larger