how to train ViT

Training recipes (the "alchemy" of model training):

[Google 2021] How to train your ViT? Data, Augmentation, and Regularization in Vision Transformers (Google, rwightman). Most of these points were already mentioned in the original ViT paper; this is essentially a set of supplementary experiments.

[Google 2022] Model soups: averaging weights of multiple fine-tuned models improves accuracy without increasing inference time. Averages the weights of several fine-tuned models.

[Facebook DeiT 2021] Training data-efficient image transformers & distillation through attention. A compendium of standard training tricks.

How to train your ViT? Data, Augmentation, and Regularization in Vision Transformers

  1. Motivation

    • ViT vs. CNN
      • no built-in translation invariance (loses the convolutional inductive bias)
      • requires large datasets and strong AugReg
    • The paper's contribution is a large body of experiments showing that carefully selected regularization and augmentation is about as useful as naively adding 10x more data; in short, it offers insight on hyper-parameter choices
  2. Method

    • basic setup

      • pre-training + transfer learning: run with the original Google Research codebase, on TPUs

      • inference: timm's PyTorch ViT implementation, run on V100s

      • data

        • pre-training: ImageNet
        • transfer: CIFAR
      • models

        • [ViT-Ti, ViT-S, ViT-B and ViT-L](https://github.com/rwightman/pytorch-image-models/blob/master/timm/models/vision_transformer.py)
        • the knobs that set model scale:
          • depth: 12, 24, 32, 40, 48
          • embed_dim: 192, 384, 768, 1024, 1280, 1408
          • mlp_ratio: 4, 48/11, 64/13
          • num_heads: 3, 6, 12, 16
        • variables that additionally affect compute:
          • resolution: 224, 384
          • patch_size: 8, 16, 32
        • The only difference from the original ViT is that the hidden layer in the MLP head (the fc-tanh) is removed; it reportedly brings no performance gain and can cause optimization instabilities (see the sketch below for how the knobs above map onto timm's ViT).
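
A minimal sketch (not the paper's code) of how the knobs above map onto timm's `VisionTransformer` class from the file linked above; the specific values are only an example configuration, not one of the paper's models.

```python
import torch
from timm.models.vision_transformer import VisionTransformer

# Illustrative configuration; argument names match the knobs listed above.
model = VisionTransformer(
    img_size=224,      # resolution: 224 / 384
    patch_size=16,     # patch_size: 8 / 16 / 32
    embed_dim=384,     # embed_dim: 192 / 384 / 768 / ...
    depth=12,          # depth: 12 / 24 / ...
    num_heads=6,       # num_heads: 3 / 6 / 12 / 16
    mlp_ratio=4.0,     # mlp_ratio
    num_classes=1000,
)

logits = model(torch.randn(1, 3, 224, 224))   # -> shape (1, 1000)
```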

    • Regularization and data augmentations

      • dropout of 0.1 after each dense layer + activation, except the QKV dense layers (Dense_QKV)
      • stochastic depth: the drop rate grows linearly with block index, up to 0.1 (see the sketch after this list)
      • Mixup: the alpha of the Beta distribution
      • RandAugment: number of layers L and magnitude M
      • weight decay: 0.1 / 0.03; note this is the raw (decoupled) value, the actual update is new_p = new_p - p * weight_decay * lr, so weight_decay * lr is the effective decay, on the order of 1e-4 to 1e-5
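
A small sketch of the two less obvious points above: the linearly increasing stochastic-depth rate and the decoupled weight-decay update. The parameter tensor and the zero Adam step are placeholders, not the paper's code.

```python
import torch

# Stochastic depth: per-block drop rate grows linearly with depth, up to 0.1.
depth, max_drop = 12, 0.1
drop_path_rates = [max_drop * i / (depth - 1) for i in range(depth)]

# Decoupled weight decay: the raw value (0.1 or 0.03) is multiplied by the
# learning rate inside the update, so the effective decay is ~1e-4 / 1e-5.
lr, weight_decay = 1e-3, 0.1
p = torch.randn(8)                # a parameter tensor (placeholder)
adam_step = torch.zeros_like(p)   # the Adam update direction (placeholder)
p = p - lr * adam_step - lr * weight_decay * p   # new_p = p - p * weight_decay * lr
```
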
    • Pre-training

      • Adam: betas (0.9, 0.999)
      • batch size:4096
      • cosine lr schedule with linear warmup(10k steps)
      • gradients clip at global norm 1
      • crop & random horizontal flipping
      • epochs: 300 on ImageNet-1k; 30 to 300 on ImageNet-21k
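
A rough sketch of the pre-training pieces listed above (Adam betas, linear warmup + cosine decay, gradient clipping at global norm 1). The base learning rate, total step count, and the tiny stand-in model/batch are placeholders, not the paper's values.

```python
import math
import torch

model = torch.nn.Linear(768, 1000)   # stand-in for the ViT
base_lr, warmup_steps, total_steps = 1e-3, 10_000, 300_000   # base_lr / total_steps are placeholders
optimizer = torch.optim.Adam(model.parameters(), lr=base_lr, betas=(0.9, 0.999))

def lr_at(step):
    # linear warmup for 10k steps, then cosine decay to zero
    if step < warmup_steps:
        return base_lr * step / warmup_steps
    progress = (step - warmup_steps) / (total_steps - warmup_steps)
    return base_lr * 0.5 * (1.0 + math.cos(math.pi * progress))

for step in range(3):                            # a few steps just to show the loop shape
    for group in optimizer.param_groups:
        group["lr"] = lr_at(step)
    x = torch.randn(8, 768)                      # stand-in for a 4096-image batch
    loss = model(x).logsumexp(dim=-1).mean()     # stand-in loss
    loss.backward()
    torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)  # clip at global norm 1
    optimizer.step()
    optimizer.zero_grad()
```
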
    • Fine-tuning

      • SGD: momentum 0.9
      • batch size:512
      • cosine lr schedule with linear warmup
      • gradients clip at global norm 1
      • resolution: 224 or 384
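
Fine-tuning reuses the same schedule shape and gradient clipping; only the optimizer and batch size change. A minimal sketch (the learning rate here is a placeholder):

```python
import torch

model = torch.nn.Linear(768, 10)   # stand-in for the ViT with a new head for the transfer task
optimizer = torch.optim.SGD(model.parameters(), lr=0.01, momentum=0.9)  # lr is a placeholder
# batch size 512, cosine schedule with linear warmup, gradient clipping at
# global norm 1 as in pre-training; inputs are resized to 224 or 384.
```
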
  3. Conclusions

    • Scaling datasets with AugReg and compute: scale up the data and strengthen aug & reg together

      • well-chosen AugReg and a 10x larger dataset both improve accuracy, and by roughly the same amount

    • Transfer is the better option: always start from pre-trained weights and transfer, especially for large models

      • with limited data, training from scratch basically cannot catch up with transfer learning in accuracy

    • More data yields more generic models: the larger the dataset, the better the generalization

    • Prefer augmentation to regularization: if forced to pick one, aug > reg, but in practice use both

      • for a mid-size dataset like ImageNet-1k, any kind of AugReg helps
      • for a larger dataset like ImageNet-21k, regularization mostly hurts, but augmentation still helps

    • Choosing which pre-trained model to transfer

    • Prefer increasing patch-size to shrinking model-size: with limited GPU memory, increase the patch size before shrinking the model
      • at similar compute, Ti/16 is worse than S/32 (see the quick arithmetic below)
      • because patch size only affects the amount of compute, while model size sets the parameter count, which directly drives model performance
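
A quick back-of-the-envelope check of the Ti/16 vs. S/32 comparison. The parameter counts are the commonly quoted approximate sizes of ViT-Ti and ViT-S, not numbers taken from this paper.

```python
# Token count at 224x224: halving the patch size quadruples the number of tokens.
tokens_16 = (224 // 16) ** 2   # 196 tokens for a /16 model
tokens_32 = (224 // 32) ** 2   #  49 tokens for a /32 model

# Approximate parameter counts (for illustration only):
params_ti = 5.7e6   # ViT-Ti: embed_dim 192, 3 heads
params_s = 22e6     # ViT-S:  embed_dim 384, 6 heads

# Ti/16 processes ~4x the tokens of S/32, while S/32 carries ~4x the parameters,
# so the two land at comparable compute but very different capacity.
print(tokens_16, tokens_32, params_s / params_ti)
```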

Training data-efficient image transformers & distillation through attention

  1. Motivation

    • high-accuracy models built on big data + big models are not something everyone can afford
    • we produce a competitive model
      • using ImageNet only
      • on a single machine with 8 GPUs
      • in less than 3 days: 53 hours of pre-training + 20 hours of fine-tuning
      • model: 86M parameters, 83.1% top-1
      • Facebook keeps it practical!!!
    • we also propose a transformer-specific teacher-student strategy
      • token-based distillation
      • use a convnet as teacher
  2. Key points

    • the paper is essentially an exploration of the hyper-parameters and various tricks for training transformers

    • Knowledge Distillation (KD)

      • the paper focuses on the teacher-student setting
      • the student is trained on the teacher's softmax outputs (soft labels), i.e. the student distills knowledge from the teacher
    • the class token
      • a trainable vector
      • concatenated with the patch tokens
      • then fed through the transformer layers
      • then projected with a linear layer to predict the class
      • this structure forces self-attention to exchange information between the patch tokens and the class token (see the sketch below)
      • because the class token is the only place supervision enters, while the patch tokens are the only input variables
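
A shapes-only sketch (not DeiT's actual code) of how the class token is wired: a learned vector prepended to the patch tokens, processed together with them, with only its output fed to the linear classifier.

```python
import torch
import torch.nn as nn

B, N, D, num_classes = 2, 196, 384, 1000        # illustrative sizes
patch_tokens = torch.randn(B, N, D)             # output of the patch embedding

cls_token = nn.Parameter(torch.zeros(1, 1, D))  # trainable class token
head = nn.Linear(D, num_classes)

x = torch.cat([cls_token.expand(B, -1, -1), patch_tokens], dim=1)  # (B, 1+N, D)
# x = transformer_blocks(x)                     # self-attention mixes cls <-> patch info
logits = head(x[:, 0])                          # classify from the class token only
```
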
    • contributions
      • scaling down models: DeiT-S and DeiT-Ti, going after ResNet-50 and ResNet-18 respectively
      • introduce a new distillation procedure based on a distillation token, which plays a role analogous to the class token
      • this distillation mechanism lets the transformer learn more from a convnet teacher than from a teacher of the same architecture
      • transfers well
  3. Method

    • assume we already have a strong teacher; the task is to train a high-quality transformer by exploiting that teacher

    • Soft distillation

      • the teacher's softmax output is not used directly as a label; instead, the KL divergence between the teacher and student distributions is added
      • CE + KL loss (sketch below)
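
A sketch of a soft-distillation loss of this form; the temperature and mixing weight are illustrative values, not necessarily the paper's settings.

```python
import torch
import torch.nn.functional as F

def soft_distill_loss(student_logits, teacher_logits, targets, tau=3.0, lam=0.5):
    # CE on the ground-truth labels + temperature-scaled KL to the teacher.
    ce = F.cross_entropy(student_logits, targets)
    kl = F.kl_div(
        F.log_softmax(student_logits / tau, dim=-1),
        F.softmax(teacher_logits / tau, dim=-1),
        reduction="batchmean",
    ) * tau * tau
    return (1 - lam) * ce + lam * kl
```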

    • Hard-label distillation

      • the teacher's prediction is used directly as a label
      • CE + CE

      • experiments find the hard version works better than the soft one (sketch below)
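
A sketch of the hard-label variant; with the distillation token (next subsection) the two CE terms are computed on the class head and the distillation head respectively, but here they are shown on a single output for simplicity.

```python
import torch
import torch.nn.functional as F

def hard_distill_loss(student_logits, teacher_logits, targets):
    # The teacher's argmax serves as a second, "hard" label.
    teacher_labels = teacher_logits.argmax(dim=-1)
    return 0.5 * F.cross_entropy(student_logits, targets) \
         + 0.5 * F.cross_entropy(student_logits, teacher_labels)
```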

    • Distillation token

      • append one more token to the token sequence (see the sketch below)
      • it plays the same structural role as the class token
      • but the distillation token is optimized against the distillation component of the loss above
      • it works hand in hand with the class token
      • as a control, they also trained two independent class tokens with the plain CE loss: those two converge to a cosine similarity close to 1, i.e. the extra class token adds nothing useful, whereas the class token and the distillation token only reach about 0.93 similarity, suggesting the distillation branch does add something [isn't the difference simply because the targets differ???]
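
A shapes-only sketch of the distillation token, extending the class-token sketch above; again illustrative, not DeiT's actual code.

```python
import torch
import torch.nn as nn

B, N, D, num_classes = 2, 196, 384, 1000
patch_tokens = torch.randn(B, N, D)

cls_token = nn.Parameter(torch.zeros(1, 1, D))
dist_token = nn.Parameter(torch.zeros(1, 1, D))   # new token, same role as the class token
cls_head = nn.Linear(D, num_classes)
dist_head = nn.Linear(D, num_classes)

x = torch.cat([cls_token.expand(B, -1, -1),
               dist_token.expand(B, -1, -1),
               patch_tokens], dim=1)              # (B, 2+N, D)
# x = transformer_blocks(x)
cls_logits = cls_head(x[:, 0])    # supervised by the ground-truth label (CE term)
dist_logits = dist_head(x[:, 1])  # supervised by the teacher label (distillation term)
```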

    • Fine-tuning with distillation

      • during fine-tuning, should the teacher labels or the ground-truth labels be used?
      • experiments show the teacher labels work slightly better
    • Joint classifiers

      • add the softmax outputs of the two heads
      • then make the prediction (sketch below)
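
A tiny sketch of the joint classifier at inference time; the logits below are random stand-ins for the two heads' outputs.

```python
import torch

cls_logits = torch.randn(2, 1000)    # output of the class-token head (stand-in)
dist_logits = torch.randn(2, 1000)   # output of the distillation-token head (stand-in)

probs = cls_logits.softmax(dim=-1) + dist_logits.softmax(dim=-1)
prediction = probs.argmax(dim=-1)    # fused prediction
```
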
  4. Training details & ablation

    • Initialization

      • Transformers are highly sensitive to initialization; a bad init can prevent convergence
      • the recommendation is to initialize the weights with a truncated normal distribution (sketch below)
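
A minimal sketch of truncated-normal initialization with PyTorch; the std value is illustrative, not quoted from the paper.

```python
import torch.nn as nn

def init_vit_weights(module):
    # Truncated normal for linear weights; zero for biases.
    if isinstance(module, nn.Linear):
        nn.init.trunc_normal_(module.weight, std=0.02)   # std is illustrative
        if module.bias is not None:
            nn.init.zeros_(module.bias)

# model.apply(init_vit_weights)
```
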
    • Data-Augmentation

      • Auto-Augment, Rand-Augment, random erasing, Mixup, etc.
      • transformers require strong data augmentation: nearly all of these help
      • except Dropout: so Dropout is set to zero (see the transform sketch below)
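
A sketch of building such an augmentation pipeline with timm; the policy string and probabilities are illustrative, not necessarily DeiT's exact recipe.

```python
from timm.data import create_transform, Mixup

# RandAugment + random erasing handled inside the per-image transform.
train_transform = create_transform(
    input_size=224,
    is_training=True,
    auto_augment="rand-m9-mstd0.5",   # RandAugment policy string (illustrative)
    re_prob=0.25,                     # random-erasing probability (illustrative)
)

# Mixup / CutMix are applied per batch, outside the per-image transform.
mixup_fn = Mixup(mixup_alpha=0.8, cutmix_alpha=1.0, num_classes=1000)  # illustrative values
```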

    • Optimizers & Regularization

      • AdamW
      • the same learning-rate recipe as ViT
      • but a much smaller weight decay: a large weight decay was found to hurt convergence (sketch below)
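
A one-line sketch of the optimizer setup; the learning rate and weight-decay values are placeholders, not quoted from the paper.

```python
import torch

model = torch.nn.Linear(768, 1000)   # stand-in for DeiT
optimizer = torch.optim.AdamW(model.parameters(), lr=5e-4, weight_decay=0.05)  # values are placeholders
```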