Meta Pseudo Labels

papers

  • [MPL 2021] Meta Pseudo Labels

  • [UDA 2019] Unsupervised Data Augmentation for Consistency Training

  • [Entropy Minimization 2004] Semi-supervised Learning by Entropy Minimization

Meta Pseudo Labels

  1. Motivation

    • semi-supervised learning
      • Pseudo Labels: a fixed teacher
      • Meta Pseudo Labels: the teacher is constantly adapted by feedback from the student
    • SOTA on ImageNet: top-1 acc 90.2%
  2. Arguments

    • Pseudo Labels methods

      • teacher generates pseudo labels on unlabeled images
      • pseudo labeled images are then combined with labeled images to train the student
      • confirmation bias problem: the student's accuracy is limited by the quality of the pseudo labels
    • we propose Meta Pseudo Labels

      • teacher observes how its pseudo labels affect the student
      • then correct the bias
      • the feedback signal is the performance of the student on the labeled dataset
      • overall, the teacher and student are trained in parallel
        • the student learns from the pseudo labels produced by the teacher
        • the teacher learns from a reward signal based on how well the student performs on the labeled set
      • dataset
        • ImageNet as labeled set
        • JFT-300M as unlabeled set
      • model
        • teacher:EfficientNet-L2
        • student:EfficientNet-L2
    • main difference

      • in Pseudo Labels methods, the teacher influences the student in one direction only
      • in Meta Pseudo Labels, the teacher and the student interact with each other

  3. Method

    • notations

      • models
        • teacher model T & $\theta_T$
        • student model S & $\theta_S$
      • data
        • labeled set $(x_l, y_l)$
        • unlabeled set $(x_u)$
      • predictions
        • soft predictions by teacher $T(x_u, \theta_T)$
        • student $S(x_u, \theta_S)$ & $S(x_l, \theta_S)$
      • loss
        • $CE(q,p)$, where $q$ is a one-hot label, e.g. $CE(y_l, S(x_l, \theta_S))$
    • Pseudo Labels

      • given a fixed teacher $\theta_T$
      • train the student model to minimize the cross-entropy loss against the teacher's pseudo labels on unlabeled data
      • the hope is that $\theta_S^{PL}$ also achieves a low loss on labeled data
      • $\theta_S^{PL}$ explicitly depends on $\theta_T$: write it as $\theta_S^{PL}(\theta_T)$
      • the student loss on labeled data is therefore also a function of $\theta_T$: $L_l(\theta_S^{PL}(\theta_T))$ (both written out below)
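
      • in the notation above, these two quantities can be written roughly as follows (a sketch consistent with this note, not copied verbatim from the paper):

        $$\theta_S^{PL} = \arg\min_{\theta_S} \; \mathbb{E}_{x_u}\Big[CE\big(T(x_u, \theta_T),\, S(x_u, \theta_S)\big)\Big]$$

        $$L_l\big(\theta_S^{PL}(\theta_T)\big) = \mathbb{E}_{(x_l, y_l)}\Big[CE\big(y_l,\, S(x_l, \theta_S^{PL}(\theta_T))\big)\Big]$$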
    • Meta Pseudo Labels

      • intuition: minimize $L_l(\theta_S^{PL}(\theta_T))$ with respect to $\theta_T$

      • in practice, however, the dependency of $\theta_S^{PL}(\theta_T)$ on $\theta_T$ is extremely complicated

      • partly because hard labels drawn from the teacher's predictions are used to train the student, and computing the exact gradient would require unrolling the whole student training

      • the paper therefore uses an alternating optimization procedure: one gradient step on the student, then one gradient step on the teacher using the student's feedback

      • teacher’s auxiliary losses

        • augment the teacher's training with a supervised learning objective and a semi-supervised learning objective
        • supervised objective
          • train on labeled data
          • CE
        • semi-supervised objective
          • train on unlabeled data
          • UDA (Unsupervised Data Augmentation): apply simple augmentations to the samples and penalize inconsistent predictions between the original and augmented versions, which improves generalization
          • consistency training loss: KL divergence
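
        • putting these together, each teacher update combines (roughly) the three gradients, which is what the overall algorithm below computes:

          $$g_{\theta_T} \approx \nabla_{\theta_T}\big(L_{T,\text{MPL}} + L_{T,\text{supervised}} + L_{T,\text{UDA}}\big)$$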
      • finetuning student

        • during Meta Pseudo Labels training, the student only learns from the unlabeled data
        • so after training finishes, the student can be finetuned on the labeled data to improve accuracy
      • overall algorithm

          * there is a subscript typo in the algorithm: the teacher's UDA gradient is computed on unlabeled data, so the two $x_l$ should be $x_u$
          * the UDA paper uses a divergence between two predicted distributions for this loss, whereas here it is written as CE
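
      • a rough PyTorch sketch of one alternating update (a sketch under assumptions: the `mpl_step` helper, the toy MLPs standing in for EfficientNet-L2, and the scalar feedback `h` computed as the change in the student's labeled loss are illustrative; the teacher's UDA term is omitted):

      ```python
      # hedged sketch of one Meta Pseudo Labels update step (not the paper's exact code)
      import torch
      import torch.nn.functional as F

      def mpl_step(teacher, student, opt_t, opt_s, x_l, y_l, x_u, eta_s=0.1):
          # 1) teacher produces hard pseudo labels on the unlabeled batch
          with torch.no_grad():
              pseudo = teacher(x_u).argmax(dim=-1)
          # 2) student's labeled loss BEFORE its update (for the feedback signal)
          with torch.no_grad():
              loss_l_old = F.cross_entropy(student(x_l), y_l)
          # 3) one SGD step on the student using the pseudo labels
          opt_s.zero_grad()
          loss_s = F.cross_entropy(student(x_u), pseudo)
          loss_s.backward()
          opt_s.step()
          # 4) labeled loss AFTER the update; the improvement is the reward h
          #    (a first-order approximation of the teacher's feedback signal)
          with torch.no_grad():
              loss_l_new = F.cross_entropy(student(x_l), y_l)
          h = eta_s * (loss_l_old - loss_l_new)
          # 5) teacher update: MPL feedback term + its own supervised CE
          #    (the UDA consistency term from the note would be added here too)
          opt_t.zero_grad()
          loss_mpl = h * F.cross_entropy(teacher(x_u), pseudo)
          loss_sup = F.cross_entropy(teacher(x_l), y_l)
          (loss_mpl + loss_sup).backward()
          opt_t.step()
          return loss_s.item(), loss_l_new.item()

      # toy usage: small MLPs stand in for the EfficientNet-L2 teacher/student
      def mlp():
          return torch.nn.Sequential(
              torch.nn.Linear(32, 64), torch.nn.ReLU(), torch.nn.Linear(64, 10))

      teacher, student = mlp(), mlp()
      opt_t = torch.optim.SGD(teacher.parameters(), lr=0.1)
      opt_s = torch.optim.SGD(student.parameters(), lr=0.1)
      x_l, y_l = torch.randn(16, 32), torch.randint(0, 10, (16,))
      x_u = torch.randn(64, 32)
      print(mpl_step(teacher, student, opt_t, opt_s, x_l, y_l, x_u))
      ```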
        

Unsupervised Data Augmentation for Consistency Training

  1. Motivation

    • data augmentation in previous works
      • can partially alleviate the need for large amounts of labeled data
      • mostly used in the supervised setting
      • achieved limited gains
    • we propose UDA
      • apply data augmentation in semi-supervised learning setting
      • use harder and more realistic noise to generate the augmented samples
      • encourage the predictions on an unlabeled sample and its augmented version to be consistent
      • the smaller the labeled set, the larger the gain
    • verified on
      • six language tasks
      • three vision tasks
        • ImageNet-10%: top-1/top-5 68.7%/88.5%
        • full ImageNet + extra unlabeled data: top-1/top-5 79.0%/94.5%
  2. Arguments

    • semi-supervised learning
      • three categories
        • graph-based label propagation via graph convolution and graph embeddings
        • modeling prediction target as latent variables
        • consistency / smoothness enforcing
      • the last category has been shown to work well
        • enforce the model predictions on the two examples to be similar
        • methods in this category mainly differ in the design of the perturbation function
    • we propose UDA
      • use state-of-the-art data augmentation methods
      • we show that better augmentation methods (e.g. AutoAugment) lead to greater improvements
      • minimizes the KL divergence between predictions on an unlabeled example and its augmented version
      • can be applied even when the class distributions of labeled and unlabeled data mismatch
    • we propose TSA
      • a training technique
      • prevent overfitting when much more unlabeled data is available than labeled data
  3. Method

    • formulation

      • given an input $x\in U$ and a small noise $\epsilon$

      • compute the output distribution $p_{\theta}(y|x)$ and $p_{\theta}(y|x,\epsilon)$

      • minimize the divergence between two predicted distributions $D(p_{\theta}(y|x)||p_{\theta}(y|x,\epsilon))$

      • add a CE loss on labeled data

      • the UDA training objective (written out below)

        • enforce the model to be insensitive to perturbation
        • thus smoother to the changes in the input space
      • $\lambda=1$ (the weight of the consistency term) for most experiments

      • use different batch sizes for the labeled & unlabeled data
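
      • the pieces above combine into roughly the following objective (a sketch in this note's notation; $\tilde{\theta}$, a fixed copy of the current parameters used for the target distribution, is an added symbol):

        $$\min_{\theta} \; \mathbb{E}_{(x_l, y_l)}\big[CE\big(y_l, p_{\theta}(y|x_l)\big)\big] + \lambda\, \mathbb{E}_{x\in U}\,\mathbb{E}_{\epsilon}\big[D_{KL}\big(p_{\tilde{\theta}}(y|x)\,\|\,p_{\theta}(y|x,\epsilon)\big)\big]$$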

    • Augmentation Strategies for Different Tasks

      • AutoAugment for Image Classification
        • an optimal combination of augmentation operations found via RL search
      • Back translation for Text Classification
      • TF-IDF based word replacing for Text Classification
    • Trade-off Between Diversity and Validity for Data Augmentation

      • when the original sample is transformed, there is some probability that the ground-truth label changes
      • AutoAugment has already been searched for this trade-off, so no extra tuning is needed for vision tasks
      • text tasks need to tune the back-translation temperature
    • Additional Training Techniques

      • TSA(Training Signal Annealing)

        • situation: unlabeled data greatly outnumbers labeled data; we need a large enough model to exploit the large unlabeled set, but such a model easily overfits the small labeled train set

        • for each training step

          • set a threshold $\frac{1}{K}\leq \eta_t\leq 1$, where $K$ is the number of categories
          • if a sample's predicted probability on its ground-truth class exceeds this threshold, drop that sample's loss for the current step
        • $\eta_t$ serves as a ceiling to prevent the model from over-training on examples that the model is already confident about

        • gradually release the training signal of the labeled examples to alleviate overfitting

        • schedules of $\eta_t$: write $\eta_t = \alpha_t\cdot(1-\frac{1}{K}) + \frac{1}{K}$, with $\alpha_t$ following one of

          • log-schedule: $\alpha_t = 1-\exp(-\frac{t}{T}\cdot 5)$
          • linear-schedule: $\alpha_t = \frac{t}{T}$
          • exp-schedule: $\alpha_t = \exp((\frac{t}{T}-1)\cdot 5)$

          • if the model overfits easily, use the exp-schedule; in the opposite case (abundant labeled data / effective regularization), use the log-schedule
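
        • a minimal sketch of TSA applied to the supervised CE loss (the helpers `tsa_threshold` / `tsa_cross_entropy` and the masking details are assumptions for illustration, not the official implementation):

        ```python
        # hedged sketch of Training Signal Annealing (TSA) for the supervised loss
        import math
        import torch
        import torch.nn.functional as F

        def tsa_threshold(step, total_steps, num_classes, schedule="linear"):
            t = step / total_steps
            if schedule == "log":
                alpha = 1 - math.exp(-t * 5)
            elif schedule == "exp":
                alpha = math.exp((t - 1) * 5)
            else:  # linear
                alpha = t
            # map alpha in [0, 1] to eta_t in [1/K, 1]
            return alpha * (1 - 1 / num_classes) + 1 / num_classes

        def tsa_cross_entropy(logits, labels, step, total_steps, schedule="linear"):
            eta = tsa_threshold(step, total_steps, logits.size(-1), schedule)
            probs = logits.softmax(dim=-1)
            gt_prob = probs.gather(-1, labels.unsqueeze(-1)).squeeze(-1)
            # drop examples whose ground-truth probability already exceeds eta
            keep = (gt_prob <= eta).float().detach()
            loss = F.cross_entropy(logits, labels, reduction="none")
            return (loss * keep).sum() / keep.sum().clamp(min=1.0)

        # toy usage
        logits, labels = torch.randn(8, 10), torch.randint(0, 10, (8,))
        print(tsa_cross_entropy(logits, labels, step=100, total_steps=1000, schedule="exp"))
        ```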

      • Sharpening Predictions

        • situation: the predicted distributions on unlabeled examples tend to be over-flat across categories; when the task is hard and training data is scarce, the model assigns a similarly low probability to every class on unlabeled data, with no clear preference
        • in that case the KL divergence provides only a weak supervision signal
        • thus we need to sharpen the predicted distribution on unlabeled examples
        • Confidence-based masking: filter out samples the current model is not confident enough about, i.e. only keep samples whose maximum predicted probability exceeds 0.6 when computing the consistency loss
        • Entropy minimization: add an entropy term (on the unlabeled predictions) to the overall objective
        • Softmax temperature: rescale the logits before the softmax, $Softmax(logits/\tau)$; a lower temperature corresponds to a sharper distribution
        • in practice, confidence-based masking and softmax temperature work better with small labeled sets, while entropy minimization suits relatively larger labeled sets
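
        • a minimal sketch of a sharpened consistency loss combining confidence-based masking and a softmax temperature (the helper `uda_consistency_loss` and its default values are illustrative assumptions):

        ```python
        # hedged sketch: confidence masking + temperature sharpening for the consistency loss
        import torch
        import torch.nn.functional as F

        def uda_consistency_loss(logits_x, logits_aug, thresh=0.6, tau=0.4):
            with torch.no_grad():
                # sharpened target: prediction on the clean example at low temperature
                target = F.softmax(logits_x / tau, dim=-1)
                # confidence-based masking: keep only confident unlabeled examples
                keep = (F.softmax(logits_x, dim=-1).max(dim=-1).values > thresh).float()
            log_pred = F.log_softmax(logits_aug, dim=-1)
            # per-example KL(target || prediction on the augmented example)
            kl = F.kl_div(log_pred, target, reduction="none").sum(dim=-1)
            return (kl * keep).sum() / keep.sum().clamp(min=1.0)

        # toy usage
        logits_x = torch.randn(8, 10)
        logits_aug = logits_x + 0.1 * torch.randn(8, 10)
        print(uda_consistency_loss(logits_x, logits_aug))
        ```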
      • Domain-relevance Data Filtering

        • essentially also confidence-based masking: train a base model on the labeled in-domain data, run inference on the out-of-domain dataset, and keep the samples with relatively high predicted probabilities
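
        • a minimal sketch of such filtering (the helper `filter_in_domain` and the keep ratio are illustrative assumptions):

        ```python
        # hedged sketch of domain-relevance data filtering by prediction confidence
        import torch

        def filter_in_domain(base_model, x_ood, keep_ratio=0.5):
            # score out-of-domain samples by the in-domain model's max predicted probability
            with torch.no_grad():
                confidence = base_model(x_ood).softmax(dim=-1).max(dim=-1).values
            k = int(len(x_ood) * keep_ratio)
            # keep the most confidently predicted samples as extra unlabeled data
            return x_ood[confidence.topk(k).indices]

        # toy usage
        base_model = torch.nn.Sequential(torch.nn.Linear(32, 10))
        print(filter_in_domain(base_model, torch.randn(100, 32)).shape)  # torch.Size([50, 32])
        ```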

Semi-supervised Learning by Entropy Minimization