Segment Anything

  1. Overview

    • task: propose a promptable segmentation task
    • model: build SAM, a promptable model that enables zero-shot transfer
    • dataset: SA-1B (1B+ masks), built with a semi-automatic, model-in-the-loop annotation scheme

  2. Introduction

    • foundation model
      • models that can generalize to unseen/untrained tasks and data
      • often implemented with prompt engineering
      • the best-known examples in vision are CLIP/ALIGN: they align paired text and images from the web using contrastive learning
      • SAM is likewise built as a foundation model
        • pre-train it on a broad dataset, enabling powerful generalization
        • aim to solve a range of downstream segmentation problems on new data distributions via prompt engineering
    • model constraints
      • flexible prompts: point, box, and mask; initial results with free-form text prompts
      • real-time: an image's embedding is computed once and reused across many prompts; the prompt encoder & mask decoder are lightweight
      • ambiguity-aware: one prompt can predict multiple masks, allowing SAM to naturally handle ambiguity
    • data engine
      • stage 1: assisted-manual, traditional interactive labeling with the model retrained in the loop
      • stage 2: semi-automatic, prompts auto-annotate a subset of objects while annotators label the rest
      • stage 3: fully automatic, prompt with a regular grid of foreground points to annotate the whole image automatically (see the sketch below)
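
A minimal sketch of the stage-3 prompt: a regular grid of foreground points in normalized coordinates. The 32×32 size and half-cell offset follow the official automatic mask generator, but treat the exact values as assumptions:

```python
import numpy as np

def build_point_grid(n_per_side: int = 32) -> np.ndarray:
    """Regular n x n grid of foreground point prompts in [0, 1] coords,
    offset by half a cell so points sit at cell centers."""
    offset = 1.0 / (2 * n_per_side)
    coords = np.linspace(offset, 1.0 - offset, n_per_side)
    xs, ys = np.meshgrid(coords, coords)
    return np.stack([xs.ravel(), ys.ravel()], axis=-1)  # (n*n, 2) as (x, y)

grid = build_point_grid(32)   # 1024 candidate foreground points
points_px = grid * 1024       # scale to the 1024x1024 model input
```
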
  3. Segment Anything Task

    • promptable segmentation task
      • prompt: a set of foreground/background points, a rough box or mask, free-form text…
      • return a valid segmentation mask: masks for multiple plausible objects when the prompt is ambiguous
    • pre-training
        • input: for each training sample, simulate a sequence of prompts (e.g., points, boxes, masks)
        • supervision: compare the model's mask predictions against the ground truth
    • zero-shot transfer
        • at inference time, the model can already respond appropriately to any prompt
        • thus downstream tasks can be solved zero-shot by engineering appropriate prompts, i.e., modeled as prompt-engineering tasks
  4. Segment Anything Model

    • image encoder

      • MAE pre-trained ViT: ViT-H/16 with 14×14 windowed attention, 1024×1024 input
      • minimally adapted to process high-resolution inputs
      • outputs 16× downscaled image features: 64×64
      • projected to 256-dim image embeddings by a neck: 1×1 conv (dim 256) → LN → 3×3 conv (dim 256) → LN (sketched below)
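
A sketch of that neck in PyTorch, following the conv1 → LN → conv3 → LN recipe above. The 1280 input dim assumes ViT-H; LayerNorm2d is the channel-wise variant used in the official repo:

```python
import torch
import torch.nn as nn

class LayerNorm2d(nn.Module):
    """Channel-wise LayerNorm for (B, C, H, W) tensors."""
    def __init__(self, num_channels: int, eps: float = 1e-6):
        super().__init__()
        self.weight = nn.Parameter(torch.ones(num_channels))
        self.bias = nn.Parameter(torch.zeros(num_channels))
        self.eps = eps

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        u = x.mean(1, keepdim=True)
        s = (x - u).pow(2).mean(1, keepdim=True)
        x = (x - u) / torch.sqrt(s + self.eps)
        return self.weight[:, None, None] * x + self.bias[:, None, None]

# Neck: project ViT features (B, 1280, 64, 64 for ViT-H) to (B, 256, 64, 64).
neck = nn.Sequential(
    nn.Conv2d(1280, 256, kernel_size=1, bias=False),            # 1x1 conv, dim 256
    LayerNorm2d(256),
    nn.Conv2d(256, 256, kernel_size=3, padding=1, bias=False),  # 3x3 conv, dim 256
    LayerNorm2d(256),
)
```
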
    • prompt encoder

      • sparse prompts:
        • point: a positional embedding of its location + a learned bg/fg embedding
        • box: an embedding pair: (1) the top-left corner's positional embedding + a learned top-left embedding, (2) the same for the bottom-right corner
        • i.e., points/boxes: positional encodings + learned embeddings per prompt type (see the sketch below)
        • text: the text encoder from CLIP
        • all sparse embeddings are dim 256
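
A sketch of the sparse point path described above: a positional encoding of (x, y) plus a learned fg/bg type embedding. The random-Fourier-feature encoding mirrors the official PositionEmbeddingRandom in spirit, but treat the details as assumptions:

```python
import torch
import torch.nn as nn

class PointEmbedder(nn.Module):
    def __init__(self, dim: int = 256):
        super().__init__()
        # frozen random spatial frequencies for the positional encoding
        self.register_buffer("freqs", torch.randn(2, dim // 2))
        self.type_embed = nn.Embedding(2, dim)  # 0 = background, 1 = foreground

    def forward(self, xy: torch.Tensor, labels: torch.Tensor) -> torch.Tensor:
        # xy: (B, N, 2) normalized to [0, 1]; labels: (B, N) long in {0, 1}
        proj = 2 * torch.pi * (2 * xy - 1) @ self.freqs   # (B, N, dim/2)
        pos = torch.cat([proj.sin(), proj.cos()], dim=-1) # (B, N, dim)
        return pos + self.type_embed(labels)              # (B, N, 256)
```
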
      • dense prompts:
        • input masks at 4× lower resolution than the image (256×256), downscaled an additional 4× by a conv stack:
          • 2×2, s2, dim 4 conv → LN → GeLU
          • 2×2, s2, dim 16 conv → LN → GeLU
          • 1×1, s1, dim 256 conv
          • then added element-wise to the image embeddings
        • if there is no mask prompt: a learned 'no mask' embedding is added instead (see the sketch below)
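
The dense path as a sketch, reusing LayerNorm2d from the neck above; channel sizes follow the list above:

```python
import torch.nn as nn

# Downscale a 256x256 input mask (already 4x below the 1024x1024 image)
# another 4x to match the 64x64 image embedding, then add element-wise.
mask_downscaler = nn.Sequential(
    nn.Conv2d(1, 4, kernel_size=2, stride=2), LayerNorm2d(4), nn.GELU(),
    nn.Conv2d(4, 16, kernel_size=2, stride=2), LayerNorm2d(16), nn.GELU(),
    nn.Conv2d(16, 256, kernel_size=1),
)
no_mask_embed = nn.Embedding(1, 256)  # used when no mask prompt is given

# dense = mask_downscaler(mask_256)           # (B, 256, 64, 64)
# image_embedding = image_embedding + dense   # element-wise add
```
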
    • mask decoder

      • insert learned output token embeddings (analogous to a [cls] token), N×256
      • use a two-layer decoder
      • 4 steps inside each decoder layer (see the sketch after this list)
        • self-attention on the tokens
        • cross-attention from tokens (as queries) to the image embedding: Q/K/V dim 128
        • a point-wise MLP updates each token
        • cross-attention from the image (as queries) to the tokens: Q/K/V dim 128
      • MLP: hidden dim 2048, residual connection, LN, dropout 0.1
      • strong geometric / task-type conditioning
        • positional encodings are re-added to the image embedding every time it participates in attention
        • the original prompt tokens are re-added to the tokens every time they participate in attention
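
A compact sketch of one such decoder layer implementing the 4 steps and the re-added positional/prompt encodings. The official code also reduces cross-attention Q/K/V to 128 dims and differs in some norm details; this is illustrative only:

```python
import torch
import torch.nn as nn

class TwoWayLayer(nn.Module):
    """One mask-decoder layer (SAM stacks two). Queries = prompt + output
    tokens; keys/values = flattened 64x64 image embedding."""
    def __init__(self, dim: int = 256, heads: int = 8, mlp_dim: int = 2048):
        super().__init__()
        self.self_attn = nn.MultiheadAttention(dim, heads, batch_first=True)
        self.t2i_attn = nn.MultiheadAttention(dim, heads, batch_first=True)
        self.mlp = nn.Sequential(
            nn.Linear(dim, mlp_dim), nn.ReLU(), nn.Linear(mlp_dim, dim))
        self.i2t_attn = nn.MultiheadAttention(dim, heads, batch_first=True)
        self.norms = nn.ModuleList([nn.LayerNorm(dim) for _ in range(4)])

    def forward(self, tokens, image, token_pe, image_pe):
        # 1) self-attention on the tokens (original prompt tokens re-added)
        q = tokens + token_pe
        tokens = self.norms[0](tokens + self.self_attn(q, q, tokens)[0])
        # 2) cross-attention: tokens (queries) -> image (keys/values)
        tokens = self.norms[1](tokens + self.t2i_attn(
            tokens + token_pe, image + image_pe, image)[0])
        # 3) point-wise MLP updates each token
        tokens = self.norms[2](tokens + self.mlp(tokens))
        # 4) cross-attention: image (queries) -> tokens (keys/values)
        image = self.norms[3](image + self.i2t_attn(
            image + image_pe, tokens + token_pe, tokens)[0])
        return tokens, image
```
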
      • dynamic prediction head
        • upsample the updated image embedding by 4× with two transposed convs
          • 2×2, s2, dim 64 TransConv → LN → GeLU → 2×2, s2, dim 32 TransConv → GeLU
          • giving the upscaled mask embedding: 256×256×32
        • run one more token-to-image attention pass, extract the updated output tokens, and feed each through a 3-layer MLP to get a per-mask vector: N×32
        • finally, a dot product between the two gives the N×256×256 prediction masks (see the sketch below)
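
The dynamic head as a sketch: upscale the image embedding, map each output token to a 32-dim weight vector, and take a spatial dot product. LayerNorm2d as above; the extra token-to-image attention pass is omitted for brevity:

```python
import torch
import torch.nn as nn

# Upscale the updated 64x64 image embedding 4x to 256x256.
output_upscaling = nn.Sequential(
    nn.ConvTranspose2d(256, 64, kernel_size=2, stride=2),
    LayerNorm2d(64), nn.GELU(),
    nn.ConvTranspose2d(64, 32, kernel_size=2, stride=2),
    nn.GELU(),
)

hyper_mlp = nn.Sequential(  # 3-layer MLP: 256 -> 256 -> 256 -> 32
    nn.Linear(256, 256), nn.ReLU(),
    nn.Linear(256, 256), nn.ReLU(),
    nn.Linear(256, 32),
)

def predict_masks(image_emb: torch.Tensor, out_tokens: torch.Tensor):
    """image_emb: (B, 256, 64, 64) updated image embedding;
    out_tokens: (B, N, 256) updated output tokens.
    Returns (B, N, 256, 256) mask logits via a spatial dot product."""
    up = output_upscaling(image_emb)                # (B, 32, 256, 256)
    w = hyper_mlp(out_tokens)                       # (B, N, 32)
    b, c, h, wd = up.shape
    return (w @ up.flatten(2)).view(b, -1, h, wd)   # (B, N, 256, 256)
```
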
      • ambiguity-aware
        • use 3 output tokens to predict 3 masks (whole, part, and subpart)
        • compute all 3 mask losses, but backpropagate only the lowest one
        • a small head on the output tokens estimates each mask's IoU; at inference time this predicted IoU ranks the 3 masks (see the sketch below)
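
A sketch of the min-loss backprop and IoU supervision; mask_loss_fn and mask_iou are defined in the loss sketch below:

```python
import torch

def ambiguity_aware_loss(pred_masks, pred_ious, gt_mask):
    """pred_masks: (B, 3, H, W) logits for whole/part/subpart;
    pred_ious: (B, 3) from the small IoU head; gt_mask: (B, H, W)."""
    losses = torch.stack(
        [mask_loss_fn(pred_masks[:, i], gt_mask) for i in range(3)], dim=1
    )                                              # (B, 3) per-sample losses
    best = losses.argmin(dim=1)                    # lowest-loss mask per sample
    mask_loss = losses.gather(1, best[:, None]).mean()
    actual_iou = torch.stack(
        [mask_iou(pred_masks[:, i], gt_mask) for i in range(3)], dim=1
    )
    iou_loss = ((pred_ious - actual_iou.detach()) ** 2).mean()
    return mask_loss + iou_loss

# Inference: rank the 3 masks by the IoU head's prediction.
# order = pred_ious.argsort(dim=1, descending=True)
```
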
      • loss
        • mask loss: 20 × focal loss + 1 × dice loss
        • IoU loss: MSE between the predicted IoU and the mask's actual IoU (both sketched below)
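
A sketch of the loss terms; the focal alpha/gamma below are the common defaults, not values stated in this note:

```python
import torch
import torch.nn.functional as F

def focal_loss(logits, target, alpha: float = 0.25, gamma: float = 2.0):
    """Per-sample sigmoid focal loss; target is a float 0/1 mask."""
    p = torch.sigmoid(logits)
    ce = F.binary_cross_entropy_with_logits(logits, target, reduction="none")
    p_t = p * target + (1 - p) * (1 - target)
    a_t = alpha * target + (1 - alpha) * (1 - target)
    return (a_t * (1 - p_t) ** gamma * ce).flatten(1).mean(-1)

def dice_loss(logits, target, eps: float = 1.0):
    p = torch.sigmoid(logits).flatten(1)
    t = target.flatten(1)
    inter = (p * t).sum(-1)
    return 1 - (2 * inter + eps) / (p.sum(-1) + t.sum(-1) + eps)

def mask_loss_fn(logits, target):
    # 20:1 weighting of focal to dice, as in the note
    return 20.0 * focal_loss(logits, target) + 1.0 * dice_loss(logits, target)

def mask_iou(logits, target, eps: float = 1e-6):
    """Actual IoU of the thresholded prediction, used to supervise the IoU head."""
    pred = (logits > 0).float().flatten(1)
    t = target.flatten(1)
    inter = (pred * t).sum(-1)
    return inter / (pred.sum(-1) + t.sum(-1) - inter + eps)
```
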
      • Training algorithm for one sample
        • initial prompt: randomly choose points or a box as the input prompt
          • points are sampled uniformly from the GT mask
          • boxes are the GT mask's bounding box perturbed with random noise
        • subsequent points are sampled from the error region between the predicted mask and the GT mask
        • the predicted mask is also fed back to the model as a mask prompt, using the unthresholded mask logits rather than the binarized mask
        • iterate for 8 point-sampling steps plus 2 steps where no new point is added, so the model learns to refine its own mask predictions (see the sketch below)
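
A sketch of the corrective point-sampling step; the commented loop uses hypothetical model/prediction names:

```python
import torch

def sample_next_point(pred_logits: torch.Tensor, gt_mask: torch.Tensor):
    """Sample one corrective point from the error region between the
    thresholded prediction and the GT mask: a foreground point if it is a
    missed GT pixel (false negative), background if spurious (false positive)."""
    pred = pred_logits > 0                     # threshold the mask logits
    gt = gt_mask.bool()
    fn = gt & ~pred                            # missed foreground
    fp = ~gt & pred                            # hallucinated foreground
    ys, xs = torch.nonzero(fn | fp, as_tuple=True)
    if len(ys) == 0:                           # perfect mask, nothing to correct
        return None
    i = torch.randint(len(ys), (1,)).item()
    y, x = ys[i].item(), xs[i].item()
    return (x, y), (1 if fn[y, x] else 0)      # label: 1 = fg, 0 = bg

# Loop sketch (model / best-mask selection are hypothetical names):
# prompt = initial_points_or_noisy_box(gt)
# mask_logits = None
# for step in range(8):                        # + 2 extra steps w/o new points
#     out = model(image_emb, prompt, mask_prompt=mask_logits)
#     prompt.append(sample_next_point(out.best_logits, gt))
#     mask_logits = out.best_logits            # unthresholded logits fed back
```
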