knowledge distillation

paper collection

[Similarity KD 2019] Similarity-Preserving Knowledge Distillation

[CWD 2020] Channel-wise Knowledge Distillation for Dense Prediction

Similarity-Preserving Knowledge Distillation

  1. Differences from other KD methods:

    • the student is not required to mimic the representation space of the teacher, but rather to preserve the pairwise similarities in its own representation space
    • supervise the pairwise similarities between student and teacher, rather than the features themselves
    • applicable scenarios:
      • mismatched feature shapes (channel/stride); a 1x1 conv could also be used for alignment
      • KD between a CNN and a transformer, whose representations are intrinsically different
  2. Method

    • given mini-batch activation maps (features) $F \in \mathbb{R}^{b \times c \cdot h \cdot w}$ (flattened per sample)

      • compute the pairwise similarity matrix: $G = F F^\top \in \mathbb{R}^{b \times b}$, then L2-normalize each row of $G$
      • compute the KD loss: $\frac{1}{b^2}\,\|\tilde{G}_t - \tilde{G}_s\|_F^2$, i.e. the mean squared difference between the normalized similarity matrices
    • code

      import torch
      import torch.nn as nn
      import torch.nn.functional as F


      def similarity_loss(f_s, f_t):
          bs = f_s.shape[0]
          f_s = f_s.view(bs, -1)
          f_t = f_t.view(bs, -1)

          # b x b pairwise similarity (Gram) matrices, L2-normalized per row
          G_s = torch.mm(f_s, f_s.t())
          G_s = F.normalize(G_s, p=2, dim=1)
          G_t = torch.mm(f_t, f_t.t())
          G_t = F.normalize(G_t, p=2, dim=1)

          # mean squared difference over the b * b entries
          G_diff = G_t - G_s
          loss = (G_diff * G_diff).view(-1).sum() / (bs * bs)
          return loss
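    • usage (a minimal sketch; the shapes below are illustrative, not from the paper)

      # teacher and student features may differ in channel count: both are
      # flattened to (b, -1), and the compared similarity matrices are b x b
      f_t = torch.randn(8, 64, 16, 16)  # teacher feature map (b, C_t, H, W)
      f_s = torch.randn(8, 32, 16, 16)  # student feature map (b, C_s, H, W)
      loss = similarity_loss(f_s, f_t)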
  3. Other KD methods

    • response based: KL divergence loss

      • applied to the post-softmax probabilities (optionally with a temperature factor; see the temperature-scaled sketch after the KL code below)

      • $L_{RD} (p_t, p_s) = L_R (p_t, p_s)$

      • $L_R()$ is usually the KL divergence loss

      • KL divergence loss

        import torch
        import torch.nn as nn
        import torch.nn.functional as F

        # pred is passed to KLDivLoss in log space (log_softmax); this avoids the
        # precision loss of dividing two already-normalized probabilities
        pred = F.log_softmax(torch.randn(3, 5, requires_grad=True), dim=-1)
        target = F.softmax(torch.rand(3, 5), dim=-1)
        loss = nn.KLDivLoss(reduction="batchmean")(pred, target)

        # the target can also be given in log space by setting log_target=True
        log_target = F.log_softmax(torch.rand(3, 5), dim=-1)
        output = nn.KLDivLoss(reduction="batchmean", log_target=True)(pred, log_target)


        def kl_categorical(p_logit, q_logit):
            # KL(p || q) computed directly from raw logits
            p = F.softmax(p_logit, dim=-1)
            _kl = torch.sum(p * (F.log_softmax(p_logit, dim=-1) - F.log_softmax(q_logit, dim=-1)), -1)
            return torch.mean(_kl)
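      • temperature-scaled KD loss (a minimal sketch of the classic response-based loss; the name kd_loss and the default T=4.0 are illustrative choices, not from this paper)

        def kd_loss(logits_s, logits_t, T=4.0):
            # soften both distributions with temperature T and take KL(teacher || student);
            # scaling by T*T keeps gradient magnitudes comparable across temperatures
            log_p_s = F.log_softmax(logits_s / T, dim=-1)
            p_t = F.softmax(logits_t / T, dim=-1)
            return F.kl_div(log_p_s, p_t, reduction="batchmean") * (T * T)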
    • feature based

      • $L_{FD} (f_t, f_s) = L_F (\Phi_t(f_t), \Phi_s(f_s))$
      • $\Phi()$ is a transform function used to align feature dimensions (e.g. a 1x1 conv; see the sketch below)
      • $L_F()$ is a similarity function, commonly L1, L2, cross-entropy, or MMD
      • MMD: maximum mean discrepancy
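      • example (a minimal sketch of a feature-based loss, assuming a 1x1 conv as $\Phi_s$ to align the student's channels to the teacher's and L2 as $L_F$; names and shapes are illustrative)

        import torch
        import torch.nn as nn
        import torch.nn.functional as F

        class FeatureKD(nn.Module):
            def __init__(self, c_s, c_t):
                super().__init__()
                # Phi_s: 1x1 conv mapping student channels to the teacher's channel count
                self.align = nn.Conv2d(c_s, c_t, kernel_size=1)

            def forward(self, f_s, f_t):
                # L_F: L2 (MSE) distance between aligned student feature and teacher feature
                return F.mse_loss(self.align(f_s), f_t.detach())

        # illustrative shapes: student has 32 channels, teacher has 64
        fkd = FeatureKD(c_s=32, c_t=64)
        loss = fkd(torch.randn(8, 32, 16, 16), torch.randn(8, 64, 16, 16))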