SAM2 Architecture Overview and Detailed Walkthrough

The previous post, Fine-tuning the SAM2 Image Segmentation Model, recorded how to fine-tune, predict with, and evaluate SAM2 on a downstream dataset.

This post digs further into SAM2's architecture and workflow by reading and debugging the SAM2 source code.

For now this post only considers the image task and ignores the modules involved in the video task, so architecturally SAM2 here consists of three modules: the image encoder, the mask decoder, and the prompt encoder.
SAM2 architecture

🔥 Project Structure

assets: the dataset folder;
checkpoints: where the official pretrained weights are downloaded to;
notebooks: Jupyter notebooks for prediction and panoptic-style segmentation, useful for a first hands-on run;
sam2: SAM2 itself;
sam2_config: SAM2's architecture configs and model hyperparameter settings;
sav_dataset: utilities related to the official 1B-scale dataset;
tools: tools for video prediction;

So in practice the sam2 folder contains all the important content.

🔥 Workflow

This section follows the training procedure to explain how SAM2 performs data processing, the forward pass, and so on. If you only want to understand the training process (validation and prediction are largely the same) and how data flows through SAM2 to produce the final segmentation output, this section is enough.

Simplified SAM2 architecture

With the memory mechanism that SAM2 adds for the video task removed, its structure for the image task can be represented by the first-generation (SAM) diagram above.

🚠 Building the Model

sam2_model = build_sam2(model_cfg, sam2_checkpoint, device=device)
predictor = SAM2ImagePredictor(sam2_model)

First, whether for prediction or training, SAM2 is constructed with the two lines above.

Loading the weights: build_sam.py

def build_sam2(
    config_file,
    ckpt_path=None,
    device="cuda",
    mode="eval",
    hydra_overrides_extra=[],
    apply_postprocessing=True,
):

    if apply_postprocessing:
        hydra_overrides_extra = hydra_overrides_extra.copy()
        hydra_overrides_extra += [
            # dynamically fall back to multi-mask if the single mask is not stable
            "++model.sam_mask_decoder_extra_args.dynamic_multimask_via_stability=true",
            "++model.sam_mask_decoder_extra_args.dynamic_multimask_stability_delta=0.05",
            "++model.sam_mask_decoder_extra_args.dynamic_multimask_stability_thresh=0.98",
        ]
    # Read config and init model
    cfg = compose(config_name=config_file, overrides=hydra_overrides_extra)
    OmegaConf.resolve(cfg)
    model = instantiate(cfg.model, _recursive_=True)
    _load_checkpoint(model, ckpt_path)
    model = model.to(device)
    if mode == "eval":
        model.eval()
    return model

The build_sam2() function loads the model hyperparameters for the chosen model size from the config yaml and builds the model.

The SAM2 core wrapper: sam2_image_predictor.py

class SAM2ImagePredictor:
    def __init__(
        self,
        sam_model: SAM2Base,
        mask_threshold=0.0,
        max_hole_area=0.0,
        max_sprinkle_area=0.0,
    ) -> None:
        """
        Uses SAM-2 to calculate the image embedding for an image, and then
        allow repeated, efficient mask prediction given prompts.

        Arguments:
          sam_model (Sam-2): The model to use for mask prediction.
          mask_threshold (float): The threshold to use when converting mask logits
            to binary masks. Masks are thresholded at 0 by default.
          fill_hole_area (int): If fill_hole_area > 0, we fill small holes in up to
            the maximum area of fill_hole_area in low_res_masks.
        """
        super().__init__()
        self.model = sam_model
        self._transforms = SAM2Transforms(
            resolution=self.model.image_size,
            mask_threshold=mask_threshold,
            max_hole_area=max_hole_area,
            max_sprinkle_area=max_sprinkle_area,
        )

        # Predictor state
        self._is_image_set = False
        self._features = None
        self._orig_hw = None
        # Whether the predictor is set for single image or a batch of images
        self._is_batch = False

        # Predictor config
        self.mask_threshold = mask_threshold

        # Spatial dim for backbone feature maps
        self._bb_feat_sizes = [
            (256, 256),
            (128, 128),
            (64, 64),
        ]

SAM2ImagePredictor wraps the model's prompt handling, image encoding, and prediction, so most of what prediction needs can be done through it. Its initialization essentially just adds a SAM2Transforms on top of the model, which handles input preprocessing (resize, normalization, etc.) and post-processing of the output masks.
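
As a rough sketch of what the input side of SAM2Transforms does (a square resize to the model resolution plus ImageNet normalization; this illustrates the idea and is not the class itself):

import torch
from torchvision.transforms import functional as TF

def preprocess_like_sam2transforms(image_np, resolution=1024):
    """image_np: HWC uint8 RGB array -> 1x3xRxR normalized float tensor."""
    x = torch.from_numpy(image_np).permute(2, 0, 1).float() / 255.0   # HWC -> CHW, scale to [0, 1]
    x = TF.resize(x, [resolution, resolution])                        # square resize to the model input size
    x = TF.normalize(x, mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
    return x.unsqueeze(0)                                             # add the batch dimension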

SAM2ImagePredictor also provides the predict, _predict, and predict_batch functions for inference; their contents mirror the training flow, just without gradients.
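
A minimal inference sketch with a single point prompt (the file name and click coordinates are made up; the predict() argument names follow the repo's SAM2ImagePredictor, but double-check them against your version):

import numpy as np
from PIL import Image

image = np.array(Image.open("example.jpg").convert("RGB"))   # HWC, uint8, RGB
predictor.set_image(image)

input_point = np.array([[500, 375]])   # one (x, y) click
input_label = np.array([1])            # 1 = foreground, 0 = background

masks, scores, low_res_masks = predictor.predict(
    point_coords=input_point,
    point_labels=input_label,
    multimask_output=True,             # return several candidate masks plus their IoU scores
)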

🚠 Image Encoding

The forward pass starts by encoding the image:

predictor.set_image_batch(image_list)

    def set_image_batch(
        self,
        image_list: Union[torch.Tensor, List[Union[np.ndarray]]],
    ) -> None:
        """
        Calculates the image embeddings for the provided image batch, allowing
        masks to be predicted with the 'predict_batch' method.

        Arguments:
          image_list (List[np.ndarray]): The input images to embed in RGB format. The image should be in HWC format if np.ndarray
          with pixel values in [0, 255].
        """
        self.reset_predictor()
        # assert isinstance(image_list, list)
        self._orig_hw = []
        if isinstance(image_list, list):
            for image in image_list:
                assert isinstance(
                    image, np.ndarray
                ), "Images are expected to be an np.ndarray in RGB format, and of shape  HWC"
                self._orig_hw.append(image.shape[:2])
            # Transform the image to the form expected by the model
            img_batch = self._transforms.forward_batch(image_list)
            img_batch = img_batch.to(self.device)
        else:  # the input is already a torch tensor
            for image in torch.unbind(image_list, dim=0):
                self._orig_hw.append(image.shape[1:])
            img_batch = image_list.to(self.device)
        batch_size = img_batch.shape[0]
        assert (
            len(img_batch.shape) == 4 and img_batch.shape[1] == 3
        ), f"img_batch must be of size Bx3xHxW, got {img_batch.shape}"
        logging.info("Computing image embeddings for the provided images...")
        backbone_out = self.model.forward_image(img_batch)
        _, vision_feats, _, _ = self.model._prepare_backbone_features(backbone_out)
        # Add no_mem_embed, which is added to the lowest rest feat. map during training on videos
        if self.model.directly_add_no_mem_embed:
            vision_feats[-1] = vision_feats[-1] + self.model.no_mem_embed

        feats = [
            feat.permute(1, 2, 0).view(batch_size, -1, *feat_size)
            for feat, feat_size in zip(vision_feats[::-1], self._bb_feat_sizes[::-1])
        ][::-1]
        self._features = {"image_embed": feats[-1], "high_res_feats": feats[:-1]}
        self._is_image_set = True
        self._is_batch = True
        logging.info("Image embeddings computed.")

SAM2ImagePredictor's set_image() and set_image_batch() push the input images through SAM2Transforms, feed them into the image encoder, and after some light processing cache the result in self._features. image_embed is the last (lowest-resolution) level of the feature pyramid, while high_res_feats holds the earlier, higher-resolution levels; those higher-resolution features later help the output masks recover spatial resolution.
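
A quick way to inspect what gets cached (the shapes are indicative for a 1024-input config; channel counts depend on whether the high-resolution features have already passed through the decoder's conv_s0/conv_s1 projections):

feats = predictor._features
print(feats["image_embed"].shape)       # the lowest-resolution embedding, e.g. (B, 256, 64, 64)
for f in feats["high_res_feats"]:
    print(f.shape)                      # higher-resolution levels at roughly 256x256 and 128x128 spatial size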

🚠 Prompt Encoding

    if use_prompt:
        mask_input, unnorm_coords, labels, unnorm_box = predictor._prep_prompts(input_point, input_label, box=None, mask_logits=None, normalize_coords=True)      # with prompts
        sparse_embeddings, dense_embeddings = predictor.model.sam_prompt_encoder(points=(unnorm_coords, labels), boxes=None, masks=None)
    else:
        sparse_embeddings, dense_embeddings = predictor.model.sam_prompt_encoder(points=None, boxes=None, masks=None)       # no prompts

The above is an example of encoding the prompts (points with their labels, boxes, masks, etc.).

  • SAM2ImagePredictor._prep_prompts(): preprocesses the prompts, converting lists and similar formats into tensors;

  • predictor.model.sam_prompt_encoder: calls the prompt encoder's forward pass directly, shown below; the output sparse_embeddings come from embedding sparse prompts such as points and boxes, while dense_embeddings come from embedding the dense prompt, i.e. a mask.

        bs = self._get_batch_size(points, boxes, masks)
        sparse_embeddings = torch.empty(
            (bs, 0, self.embed_dim), device=self._get_device()
        )
        if points is not None:
            coords, labels = points
            point_embeddings = self._embed_points(coords, labels, pad=(boxes is None))
            sparse_embeddings = torch.cat([sparse_embeddings, point_embeddings], dim=1)
        if boxes is not None:
            box_embeddings = self._embed_boxes(boxes)
            sparse_embeddings = torch.cat([sparse_embeddings, box_embeddings], dim=1)

        if masks is not None:
            dense_embeddings = self._embed_masks(masks)
        else:
            dense_embeddings = self.no_mask_embed.weight.reshape(1, -1, 1, 1).expand(
                bs, -1, self.image_embedding_size[0], self.image_embedding_size[1]
            )

        return sparse_embeddings, dense_embeddings

🚠 Mask Decoding

Next, the prompt and image embeddings are decoded by the mask decoder to output masks:

    high_res_features = [feat_level[-1].unsqueeze(0) for feat_level in predictor._features["high_res_feats"]]
    low_res_masks, prd_scores, _, _ = predictor.model.sam_mask_decoder(
        image_embeddings=predictor._features["image_embed"],
        image_pe=predictor.model.sam_prompt_encoder.get_dense_pe(),
        sparse_prompt_embeddings=sparse_embeddings,
        dense_prompt_embeddings=dense_embeddings,
        multimask_output=False,
        repeat_image=False,
        high_res_features=high_res_features
    )
    prd_masks = predictor._transforms.postprocess_masks(low_res_masks, predictor._orig_hw[-1])

Decoding: mask_decoder.py

sam_mask_decoder invokes the mask decoder's forward pass. Its inputs include the image embedding image_embeddings, the image positional encoding image_pe, the sparse prompt embedding sparse_prompt_embeddings, the dense prompt embedding dense_prompt_embeddings, the multi-mask flag multimask_output, and the high-resolution image features high_res_features.

        masks, iou_pred, mask_tokens_out, object_score_logits = self.predict_masks(
            image_embeddings=image_embeddings,
            image_pe=image_pe,
            sparse_prompt_embeddings=sparse_prompt_embeddings,
            dense_prompt_embeddings=dense_prompt_embeddings,
            repeat_image=repeat_image,
            high_res_features=high_res_features,
        )

        # Select the correct mask or masks for output
        if multimask_output:
            masks = masks[:, 1:, :, :]
            iou_pred = iou_pred[:, 1:]
        elif self.dynamic_multimask_via_stability and not self.training:
            masks, iou_pred = self._dynamic_multimask_via_stability(masks, iou_pred)
        else:
            masks = masks[:, 0:1, :, :]
            iou_pred = iou_pred[:, 0:1]

        if multimask_output and self.use_multimask_token_for_obj_ptr:
            sam_tokens_out = mask_tokens_out[:, 1:]  # [b, 3, c] shape
        else:
            # Take the mask output token. Here we *always* use the token for single mask output.
            # At test time, even if we track after 1-click (and using multimask_output=True),
            # we still take the single mask token here. The rationale is that we always track
            # after multiple clicks during training, so the past tokens seen during training
            # are always the single mask token (and we'll let it be the object-memory token).
            sam_tokens_out = mask_tokens_out[:, 0:1]  # [b, 1, c] shape

        # Prepare output
        return masks, iou_pred, sam_tokens_out, object_score_logits

The heavy lifting clearly happens in the predict_masks() function:

    def predict_masks(
        self,
        image_embeddings: torch.Tensor,
        image_pe: torch.Tensor,
        sparse_prompt_embeddings: torch.Tensor,
        dense_prompt_embeddings: torch.Tensor,
        repeat_image: bool,
        high_res_features: Optional[List[torch.Tensor]] = None,
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """Predicts masks. See 'forward' for more details."""
        # Concatenate output tokens
        s = 0
        if self.pred_obj_scores:
            output_tokens = torch.cat(
                [
                    self.obj_score_token.weight,
                    self.iou_token.weight,
                    self.mask_tokens.weight,
                ],
                dim=0,
            )
            s = 1
        else:
            output_tokens = torch.cat(
                [self.iou_token.weight, self.mask_tokens.weight], dim=0
            )
        output_tokens = output_tokens.unsqueeze(0).expand(
            sparse_prompt_embeddings.size(0), -1, -1
        )
        tokens = torch.cat((output_tokens, sparse_prompt_embeddings), dim=1)

        # Expand per-image data in batch direction to be per-mask
        if repeat_image:
            src = torch.repeat_interleave(image_embeddings, tokens.shape[0], dim=0)
        else:
            # assert image_embeddings.shape[0] == tokens.shape[0]
            src = image_embeddings
        src = src + dense_prompt_embeddings
        assert (
            image_pe.size(0) == 1
        ), "image_pe should have size 1 in batch dim (from `get_dense_pe()`)"
        pos_src = torch.repeat_interleave(image_pe, tokens.shape[0], dim=0)
        b, c, h, w = src.shape

        # Run the transformer
        hs, src = self.transformer(src, pos_src, tokens)
        iou_token_out = hs[:, s, :]
        mask_tokens_out = hs[:, s + 1 : (s + 1 + self.num_mask_tokens), :]

        # Upscale mask embeddings and predict masks using the mask tokens
        src = src.transpose(1, 2).view(b, c, h, w)
        if not self.use_high_res_features:
            upscaled_embedding = self.output_upscaling(src)
        else:
            dc1, ln1, act1, dc2, act2 = self.output_upscaling
            feat_s0, feat_s1 = high_res_features
            upscaled_embedding = act1(ln1(dc1(src) + feat_s1))
            upscaled_embedding = act2(dc2(upscaled_embedding) + feat_s0)

        hyper_in_list: List[torch.Tensor] = []
        for i in range(self.num_mask_tokens):
            hyper_in_list.append(
                self.output_hypernetworks_mlps[i](mask_tokens_out[:, i, :])
            )
        hyper_in = torch.stack(hyper_in_list, dim=1)
        b, c, h, w = upscaled_embedding.shape
        masks = (hyper_in @ upscaled_embedding.view(b, c, h * w)).view(b, -1, h, w)

        # Generate mask quality predictions
        iou_pred = self.iou_prediction_head(iou_token_out)
        if self.pred_obj_scores:
            assert s == 1
            object_score_logits = self.pred_obj_score_head(hs[:, 0, :])
        else:
            # Obj scores logits - default to 10.0, i.e. assuming the object is present, sigmoid(10)=1
            object_score_logits = 10.0 * iou_pred.new_ones(iou_pred.shape[0], 1)

        return masks, iou_pred, mask_tokens_out, object_score_logits

Mask decoder structure

Read the code together with the figure above (italics denote the corresponding names in the figure). On the input side, the learned output tokens output_tokens (iou_token and mask_tokens) are concatenated with the sparse prompt embeddings to form tokens (output tokens + prompt tokens), while the image embedding and the dense prompt embedding (the mask) are simply added together to form src (image embedding); then tokens, src, and the image positional encoding pos_src are fed into self.transformer (the 2-layer two-way transformer SAM2 uses, the red part in the figure).

On the output side, the high-resolution features (stride 4, 8 feats) can optionally be used in the upscaling step (transposed convolutions, conv. trans) that generates the masks. The transformer output hs is split into three parts, each passing through its own MLP (the three mlps on the right) to produce the masks, the IoU prediction scores, and the occlusion score; the first part is also taken directly, before its MLP, as the object pointer for the target frame (the occlusion score and the pointer are only used for the video task). The masks are obtained by a matrix multiplication between the upscaled embedding and the output of the mask-token MLPs; when several masks are generated at once, each mask has its own MLP (inside self.output_hypernetworks_mlps).
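
To make that final mask product concrete, here is the shape arithmetic of the hypernetwork step with illustrative sizes (the 32-channel / 256-pixel numbers are only an example consistent with the b+ scale):

import torch

B, num_mask_tokens, C, H, W = 2, 4, 32, 256, 256      # illustrative sizes
hyper_in = torch.randn(B, num_mask_tokens, C)          # one MLP output per mask token
upscaled_embedding = torch.randn(B, C, H, W)           # upscaled / high-res-fused image features
masks = (hyper_in @ upscaled_embedding.view(B, C, H * W)).view(B, -1, H, W)
print(masks.shape)                                     # torch.Size([2, 4, 256, 256])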

Post-processing into the final masks

The low-resolution output masks need to be converted back to the original resolution, using the post-processing function postprocess_masks() of the SAM2Transforms mentioned earlier:

    def postprocess_masks(self, masks: torch.Tensor, orig_hw) -> torch.Tensor:
        """
        Perform PostProcessing on output masks.
        """
        from sam2.utils.misc import get_connected_components

        masks = masks.float()
        if self.max_hole_area > 0:
            # Holes are those connected components in background with area <= self.fill_hole_area
            # (background regions are those with mask scores <= self.mask_threshold)
            mask_flat = masks.flatten(0, 1).unsqueeze(1)  # flatten as 1-channel image
            labels, areas = get_connected_components(mask_flat <= self.mask_threshold)
            is_hole = (labels > 0) & (areas <= self.max_hole_area)
            is_hole = is_hole.reshape_as(masks)
            # We fill holes with a small positive mask score (10.0) to change them to foreground.
            masks = torch.where(is_hole, self.mask_threshold + 10.0, masks)

        if self.max_sprinkle_area > 0:
            labels, areas = get_connected_components(mask_flat > self.mask_threshold)
            is_hole = (labels > 0) & (areas <= self.max_sprinkle_area)
            is_hole = is_hole.reshape_as(masks)
            # We fill holes with negative mask score (-10.0) to change them to background.
            masks = torch.where(is_hole, self.mask_threshold - 10.0, masks)

        masks = F.interpolate(masks, orig_hw, mode="bilinear", align_corners=False)
        return masks

In short, this function fills holes, removes small speckles, and interpolates back to the original size, returning the true original-size predicted masks.
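
After post-processing, the logits still have to be thresholded to obtain binary masks; a minimal sketch (the threshold follows the predictor's default mask_threshold=0.0):

prd_masks = predictor._transforms.postprocess_masks(low_res_masks, predictor._orig_hw[-1])
binary_masks = (prd_masks > predictor.mask_threshold).cpu().numpy()   # boolean masks at original size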

==At this point we have covered the key code used by training, validation, and prediction alike; for the finer details of the forward pass and the network structure, see the next major section.==

🔥 Network Structure

Next we look directly at the structure of each part of the network and the techniques it uses. The sam2/modeling folder contains the entire SAM2 structure.

modeling/
├── backbones/
│   ├── hieradet.py
│   ├── image_encoder.py
│   └── utils.py
├── sam/
│   ├── mask_decoder.py
│   ├── prompt_encoder.py
│   └── transformer.py
├── memory_attention.py
├── memory_encoder.py
├── position_encoding.py
├── sam2_base.py
└── sam2_utils.py

As you can see, the image encoder lives in backbones, the mask decoder and prompt encoder in sam, and the memory mechanism newly added in SAM2 directly under modeling (video-related, not examined in detail here).

🚠 Utilities: sam2_utils.py

This file holds some of SAM2's utility functions and classes:

  • select_closest_cond_frames(): finds the closest conditioning (prompted) frames, used for video;

  • get_1d_sine_pe(): produces a 1D sinusoidal positional encoding;

  • get_activation_fn(): returns the activation function specified by the config;

  • custom layers: DropPath (a dropout variant), MLP, and LayerNorm2d.

🚠 Positional Encoding: position_encoding.py

Three positional encodings are provided: the 2D sinusoidal PositionEmbeddingSine, the random spatial-frequency PositionEmbeddingRandom, and rotary position embedding (RoPE).

SAM2's image encoder appears to use PositionEmbeddingSine, while the prompt encoder uses PositionEmbeddingRandom, which maps coordinates into [0, 1], transforms them with a randomly generated matrix, and finally produces the encoding through sine and cosine functions.
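
A minimal re-implementation of the idea behind PositionEmbeddingRandom (random Fourier features over normalized coordinates); this is a sketch of the math, not the library class itself:

import torch

num_pos_feats = 128
gaussian_matrix = torch.randn(2, num_pos_feats)            # fixed random projection matrix

def encode_points(coords_01: torch.Tensor) -> torch.Tensor:
    """coords_01: (..., 2) normalized to [0, 1] -> (..., 2 * num_pos_feats)."""
    proj = (2 * coords_01 - 1) @ gaussian_matrix           # map to [-1, 1], then project
    proj = 2 * torch.pi * proj
    return torch.cat([torch.sin(proj), torch.cos(proj)], dim=-1)

print(encode_points(torch.rand(5, 2)).shape)               # torch.Size([5, 256])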

🚠 The SAM2 Network: sam2_base.py

SAM2Base is the entire SAM2 network. Its __init__ is a long list of architecture-related parameter settings that we will not go through in detail; the most important part is its call to SAM2Base._build_sam_heads().

    def _build_sam_heads(self):
        """Build SAM-style prompt encoder and mask decoder."""
        self.sam_prompt_embed_dim = self.hidden_dim
        self.sam_image_embedding_size = self.image_size // self.backbone_stride

        # build PromptEncoder and MaskDecoder from SAM
        # (their hyperparameters like `mask_in_chans=16` are from SAM code)
        self.sam_prompt_encoder = PromptEncoder(
            embed_dim=self.sam_prompt_embed_dim,
            image_embedding_size=(
                self.sam_image_embedding_size,
                self.sam_image_embedding_size,
            ),
            input_image_size=(self.image_size, self.image_size),
            mask_in_chans=16,
        )
        self.sam_mask_decoder = MaskDecoder(
            num_multimask_outputs=3,
            transformer=TwoWayTransformer(
                depth=2,
                embedding_dim=self.sam_prompt_embed_dim,
                mlp_dim=2048,
                num_heads=8,
            ),
            transformer_dim=self.sam_prompt_embed_dim,
            iou_head_depth=3,
            iou_head_hidden_dim=256,
            use_high_res_features=self.use_high_res_features_in_sam,
            iou_prediction_use_sigmoid=self.iou_prediction_use_sigmoid,
            pred_obj_scores=self.pred_obj_scores,
            pred_obj_scores_mlp=self.pred_obj_scores_mlp,
            use_multimask_token_for_obj_ptr=self.use_multimask_token_for_obj_ptr,
            **(self.sam_mask_decoder_extra_args or {}),
        )
        if self.use_obj_ptrs_in_encoder:
            # a linear projection on SAM output tokens to turn them into object pointers
            self.obj_ptr_proj = torch.nn.Linear(self.hidden_dim, self.hidden_dim)
            if self.use_mlp_for_obj_ptr_proj:
                self.obj_ptr_proj = MLP(
                    self.hidden_dim, self.hidden_dim, self.hidden_dim, 3
                )
        else:
            self.obj_ptr_proj = torch.nn.Identity()
        if self.proj_tpos_enc_in_obj_ptrs:
            # a linear projection on temporal positional encoding in object pointers to
            # avoid potential interference with spatial positional encoding
            self.obj_ptr_tpos_proj = torch.nn.Linear(self.hidden_dim, self.mem_dim)
        else:
            self.obj_ptr_tpos_proj = torch.nn.Identity()

_build_sam_heads() mainly builds the prompt encoder and the mask decoder; the image encoder is passed in from outside (set in the yaml file). That is why the yaml files contain no parameters for these two parts: they are configured here.

  • _forward_sam_heads(): the forward pass through the prompt encoder and mask decoder;

  • forward_image(): the forward pass of the visual backbone, i.e. the image encoder;

  • _prepare_backbone_features(): as mentioned earlier, it parses the image-encoder output into the features used downstream;

  • SAM2Base disables forward; it has to be driven through the SAM2ImagePredictor mentioned earlier.

  • Apart from the functions above, everything else in SAM2Base relates to the video task.

==At this point we have a big-picture understanding of the whole SAM2 model. SAM2ImagePredictor and SAM2VideoPredictor are essentially "subclasses" of SAM2Base that use its functionality to different extents, which is why the official code ships them separately for different needs.==

🚠 Image Encoder: image_encoder.py

The image encoder is built by reading a yaml file like the one below; this is the config of the SAM2 b+ model with the memory-mechanism part omitted.

model:
  _target_: sam2.modeling.sam2_base.SAM2Base
  image_encoder:
    _target_: sam2.modeling.backbones.image_encoder.ImageEncoder
    scalp: 1
    trunk:
      _target_: sam2.modeling.backbones.hieradet.Hiera
      embed_dim: 112
      num_heads: 2
    neck:
      _target_: sam2.modeling.backbones.image_encoder.FpnNeck
      position_encoding:
        _target_: sam2.modeling.position_encoding.PositionEmbeddingSine
        num_pos_feats: 256
        normalize: true
        scale: null
        temperature: 10000
      d_model: 256
      backbone_channel_list: [896, 448, 224, 112]
      fpn_top_down_levels: [2, 3]  # output level 0 and 1 directly use the backbone features
      fpn_interp_model: nearest

class ImageEncoder(nn.Module):
    def __init__(
        self,
        trunk: nn.Module,
        neck: nn.Module,
        scalp: int = 0,
    ):
        super().__init__()
        self.trunk = trunk
        self.neck = neck
        self.scalp = scalp
        assert (
            self.trunk.channel_list == self.neck.backbone_channel_list
        ), f"Channel dims of trunk and neck do not match. Trunk: {self.trunk.channel_list}, neck: {self.neck.backbone_channel_list}"

    def forward(self, sample: torch.Tensor):
        # Forward through backbone
        features, pos = self.neck(self.trunk(sample))
        if self.scalp > 0:
            # Discard the lowest resolution features
            features, pos = features[: -self.scalp], pos[: -self.scalp]

        src = features[-1]
        output = {
            "vision_features": src,
            "vision_pos_enc": pos,
            "backbone_fpn": features,
        }
        return output

ImageEncoder consists of a trunk and a neck, both of which were already instantiated in build_sam.py when the yaml config was read, so here they are simply assigned to attributes. These two parts can be compared to the backbone (trunk) and neck of the YOLO family (the left and middle parts of the figure below); conceptually they are almost identical, which becomes clear as we read the code.

YOLOv8

The forward pass is also fairly simple: the generated output contains the final-level features vision_features, the positional encodings vision_pos_enc, and the multi-scale features backbone_fpn. The details live in the two sub-modules, so next we look at the two parts of the image encoder.
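
A hedged sanity check of that output dict (the 1024 input size and the resulting spatial sizes are assumptions based on the b+ config above, with scalp=1 dropping the lowest-resolution level):

import torch

device = next(sam2_model.parameters()).device
with torch.no_grad():
    out = sam2_model.image_encoder(torch.randn(1, 3, 1024, 1024, device=device))
for f in out["backbone_fpn"]:
    print(f.shape)                      # roughly (1, 256, 256, 256), (1, 256, 128, 128), (1, 256, 64, 64)
print(out["vision_features"].shape)     # the last (lowest-resolution) level, e.g. (1, 256, 64, 64)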

Trunk: hieradet.py

class MultiScaleBlock(nn.Module):
    def __init__(
        self,
        dim: int,
        dim_out: int,
        num_heads: int,
        mlp_ratio: float = 4.0,
        drop_path: float = 0.0,
        norm_layer: Union[nn.Module, str] = "LayerNorm",
        q_stride: Tuple[int, int] = None,
        act_layer: nn.Module = nn.GELU,
        window_size: int = 0,
    ):
        super().__init__()

        if isinstance(norm_layer, str):
            norm_layer = partial(getattr(nn, norm_layer), eps=1e-6)

        self.dim = dim
        self.dim_out = dim_out
        self.norm1 = norm_layer(dim)

        self.window_size = window_size

        self.pool, self.q_stride = None, q_stride
        if self.q_stride:
            self.pool = nn.MaxPool2d(
                kernel_size=q_stride, stride=q_stride, ceil_mode=False
            )

        self.attn = MultiScaleAttention(
            dim,
            dim_out,
            num_heads=num_heads,
            q_pool=self.pool,
        )
        self.drop_path = DropPath(drop_path) if drop_path > 0.0 else nn.Identity()

        self.norm2 = norm_layer(dim_out)
        self.mlp = MLP(
            dim_out,
            int(dim_out * mlp_ratio),
            dim_out,
            num_layers=2,
            activation=act_layer,
        )

        if dim != dim_out:
            self.proj = nn.Linear(dim, dim_out)

MultiScaleBlock is the basic transformer block, with all the expected pieces: LayerNorm normalization, the MultiScaleAttention attention layer, DropPath regularization, and an MLP as the FFN. MultiScaleAttention is a multi-head dot-product attention implemented in the same file (not much different from ViT's). Its forward pass is unremarkable, so it is not shown here.

class Hiera(nn.Module):
    """
    Reference: https://arxiv.org/abs/2306.00989
    """

    def __init__(
        self,
        embed_dim: int = 96,  # initial embed dim
        num_heads: int = 1,  # initial number of heads
        drop_path_rate: float = 0.0,  # stochastic depth
        q_pool: int = 3,  # number of q_pool stages
        q_stride: Tuple[int, int] = (2, 2),  # downsample stride bet. stages
        stages: Tuple[int, ...] = (2, 3, 16, 3),  # blocks per stage
        dim_mul: float = 2.0,  # dim_mul factor at stage shift
        head_mul: float = 2.0,  # head_mul factor at stage shift
        window_pos_embed_bkg_spatial_size: Tuple[int, int] = (14, 14),
        # window size per stage, when not using global att.
        window_spec: Tuple[int, ...] = (
            8,
            4,
            14,
            7,
        ),
        # global attn in these blocks
        global_att_blocks: Tuple[int, ...] = (
            12,
            16,
            20,
        ),
        return_interm_layers=True,  # return feats from every stage
    ):
        super().__init__()

        assert len(stages) == len(window_spec)
        self.window_spec = window_spec

        depth = sum(stages)
        self.q_stride = q_stride
        self.stage_ends = [sum(stages[:i]) - 1 for i in range(1, len(stages) + 1)]
        assert 0 <= q_pool <= len(self.stage_ends[:-1])
        self.q_pool_blocks = [x + 1 for x in self.stage_ends[:-1]][:q_pool]
        self.return_interm_layers = return_interm_layers

        self.patch_embed = PatchEmbed(
            embed_dim=embed_dim,
        )
        # Which blocks have global att?
        self.global_att_blocks = global_att_blocks

        # Windowed positional embedding (https://arxiv.org/abs/2311.05613)
        self.window_pos_embed_bkg_spatial_size = window_pos_embed_bkg_spatial_size
        self.pos_embed = nn.Parameter(
            torch.zeros(1, embed_dim, *self.window_pos_embed_bkg_spatial_size)
        )
        self.pos_embed_window = nn.Parameter(
            torch.zeros(1, embed_dim, self.window_spec[0], self.window_spec[0])
        )

        dpr = [
            x.item() for x in torch.linspace(0, drop_path_rate, depth)
        ]  # stochastic depth decay rule

        cur_stage = 1
        self.blocks = nn.ModuleList()

        for i in range(depth):
            dim_out = embed_dim
            # lags by a block, so first block of
            # next stage uses an initial window size
            # of previous stage and final window size of current stage
            window_size = self.window_spec[cur_stage - 1]

            if self.global_att_blocks is not None:
                window_size = 0 if i in self.global_att_blocks else window_size

            if i - 1 in self.stage_ends:
                dim_out = int(embed_dim * dim_mul)
                num_heads = int(num_heads * head_mul)
                cur_stage += 1

            block = MultiScaleBlock(
                dim=embed_dim,
                dim_out=dim_out,
                num_heads=num_heads,
                drop_path=dpr[i],
                q_stride=self.q_stride if i in self.q_pool_blocks else None,
                window_size=window_size,
            )

            embed_dim = dim_out
            self.blocks.append(block)

        self.channel_list = (
            [self.blocks[i].dim_out for i in self.stage_ends[::-1]]
            if return_interm_layers
            else [self.blocks[-1].dim_out]
        )
  
    def _get_pos_embed(self, hw: Tuple[int, int]) -> torch.Tensor:
        h, w = hw
        window_embed = self.pos_embed_window
        pos_embed = F.interpolate(self.pos_embed, size=(h, w), mode="bicubic")
        pos_embed = pos_embed + window_embed.tile(
            [x // y for x, y in zip(pos_embed.shape, window_embed.shape)]
        )
        pos_embed = pos_embed.permute(0, 2, 3, 1)
        return pos_embed

    def forward(self, x: torch.Tensor) -> List[torch.Tensor]:
        x = self.patch_embed(x)
        # x: (B, H, W, C)

        # Add pos embed
        x = x + self._get_pos_embed(x.shape[1:3])

        outputs = []
        for i, blk in enumerate(self.blocks):
            x = blk(x)
            if (i == self.stage_ends[-1]) or (
                i in self.stage_ends and self.return_interm_layers
            ):
                feats = x.permute(0, 3, 1, 2)
                outputs.append(feats)

        return outputs

Hiera is the trunk itself. __init__() looks complicated, so start with forward(): the image is split into patches, a positional embedding is added (note this is not a sinusoidal encoding but a learned parameter), and x is then passed through the encoder (a stack of MultiScaleBlocks); the feature at the end of each stage is saved into outputs and handed to the neck. This is not much different from the encoder in other transformer works.

Looking back at __init__, then, it simply builds the patch-embedding layer, the positional-embedding parameters, and the encoder blocks. PatchEmbed does the patchifying with a convolution, and each MultiScaleBlock has its own window (patch) size, so the output features effectively come at several different scales.
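
As a concrete check on the stage bookkeeping in __init__, here is the arithmetic for the default stages=(2, 3, 16, 3) and q_pool=3 (a worked example, not library code):

stages = (2, 3, 16, 3)
depth = sum(stages)                                               # 24 blocks in total
stage_ends = [sum(stages[:i]) - 1 for i in range(1, len(stages) + 1)]
print(stage_ends)            # [1, 4, 20, 23]: index of the last block in each stage
q_pool = 3
q_pool_blocks = [x + 1 for x in stage_ends[:-1]][:q_pool]
print(q_pool_blocks)         # [2, 5, 21]: the first block of each later stage applies the 2x2 q-pooling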

Neck: FpnNeck

class FpnNeck(nn.Module):
    """
    A modified variant of Feature Pyramid Network (FPN) neck
    (we remove output conv and also do bicubic interpolation similar to ViT
    pos embed interpolation)
    """

    def __init__(
        self,
        position_encoding: nn.Module,
        d_model: int,
        backbone_channel_list: List[int],
        kernel_size: int = 1,
        stride: int = 1,
        padding: int = 0,
        fpn_interp_model: str = "bilinear",
        fuse_type: str = "sum",
        fpn_top_down_levels: Optional[List[int]] = None,
    ):
        """Initialize the neck
        :param trunk: the backbone
        :param position_encoding: the positional encoding to use
        :param d_model: the dimension of the model
        :param neck_norm: the normalization to use
        """
        super().__init__()
        self.position_encoding = position_encoding
        self.convs = nn.ModuleList()
        self.backbone_channel_list = backbone_channel_list
        for dim in backbone_channel_list:
            current = nn.Sequential()
            current.add_module(
                "conv",
                nn.Conv2d(
                    in_channels=dim,
                    out_channels=d_model,
                    kernel_size=kernel_size,
                    stride=stride,
                    padding=padding,
                ),
            )

            self.convs.append(current)
        self.fpn_interp_model = fpn_interp_model
        assert fuse_type in ["sum", "avg"]
        self.fuse_type = fuse_type

        # levels to have top-down features in its outputs
        # e.g. if fpn_top_down_levels is [2, 3], then only outputs of level 2 and 3
        # have top-down propagation, while outputs of level 0 and level 1 have only
        # lateral features from the same backbone level.
        if fpn_top_down_levels is None:
            # default is to have top-down features on all levels
            fpn_top_down_levels = range(len(self.convs))
        self.fpn_top_down_levels = list(fpn_top_down_levels)

The neck of the image encoder is a feature pyramid; its initialization sets up the positional encoding (PositionEmbeddingSine) and a list of convolution layers.

    def forward(self, xs: List[torch.Tensor]):
        out = [None] * len(self.convs)
        pos = [None] * len(self.convs)
        assert len(xs) == len(self.convs)
        # fpn forward pass
        # see https://github.com/facebookresearch/detectron2/blob/main/detectron2/modeling/backbone/fpn.py
        prev_features = None
        # forward in top-down order (from low to high resolution)
        n = len(self.convs) - 1
        for i in range(n, -1, -1):
            x = xs[i]
            lateral_features = self.convs[n - i](x)
            if i in self.fpn_top_down_levels and prev_features is not None:
                top_down_features = F.interpolate(
                    prev_features.to(dtype=torch.float32),
                    scale_factor=2.0,
                    mode=self.fpn_interp_model,
                    align_corners=(
                        None if self.fpn_interp_model == "nearest" else False
                    ),
                    antialias=False,
                )
                prev_features = lateral_features + top_down_features
                if self.fuse_type == "avg":
                    prev_features /= 2
            else:
                prev_features = lateral_features
            x_out = prev_features
            out[i] = x_out
            pos[i] = self.position_encoding(x_out).to(x_out.dtype)

        return out, pos

In the forward pass, the multi-scale features from the trunk each go through their own convolution to produce lateral_features; the levels listed in self.fpn_top_down_levels additionally receive upsampled top-down features. The results serve as the high_res_feats seen in the ==🚠 Image Encoding== part of the previous section, helping to recover high-resolution masks. The convolution outputs out are then positionally encoded to produce pos.

==So stepping back to the ImageEncoder level: the whole image encoder extracts features at multiple scales just like YOLO, feeds them into a pyramid for further processing, and finally outputs the image embedding. The differences are that the trunk uses attention instead of convolution, and the pyramid does not do the elaborate back-and-forth fusion of YOLO's PAN.==

🚠 Prompt Encoder: prompt_encoder.py

class PromptEncoder(nn.Module):
    def __init__(
        self,
        embed_dim: int,
        image_embedding_size: Tuple[int, int],
        input_image_size: Tuple[int, int],
        mask_in_chans: int,
        activation: Type[nn.Module] = nn.GELU,
    ) -> None:
        """
        Encodes prompts for input to SAM's mask decoder.

        Arguments:
          embed_dim (int): The prompts' embedding dimension
          image_embedding_size (tuple(int, int)): The spatial size of the
            image embedding, as (H, W).
          input_image_size (int): The padded size of the image as input
            to the image encoder, as (H, W).
          mask_in_chans (int): The number of hidden channels used for
            encoding input masks.
          activation (nn.Module): The activation to use when encoding
            input masks.
        """
        super().__init__()
        self.embed_dim = embed_dim
        self.input_image_size = input_image_size
        self.image_embedding_size = image_embedding_size
        self.pe_layer = PositionEmbeddingRandom(embed_dim // 2)

        self.num_point_embeddings: int = 4  # pos/neg point + 2 box corners
        point_embeddings = [
            nn.Embedding(1, embed_dim) for i in range(self.num_point_embeddings)
        ]
        self.point_embeddings = nn.ModuleList(point_embeddings)
        self.not_a_point_embed = nn.Embedding(1, embed_dim)

        self.mask_input_size = (
            4 * image_embedding_size[0],
            4 * image_embedding_size[1],
        )
        self.mask_downscaling = nn.Sequential(
            nn.Conv2d(1, mask_in_chans // 4, kernel_size=2, stride=2),
            LayerNorm2d(mask_in_chans // 4),
            activation(),
            nn.Conv2d(mask_in_chans // 4, mask_in_chans, kernel_size=2, stride=2),
            LayerNorm2d(mask_in_chans),
            activation(),
            nn.Conv2d(mask_in_chans, embed_dim, kernel_size=1),
        )
        self.no_mask_embed = nn.Embedding(1, embed_dim)

PromptEncoder mainly initializes the positional encoding pe_layer, the point embedding layers point_embeddings (2 + 2 = 4 embeddings: two for positive/negative click points and two for the two box corners), and the mask embedding mask_downscaling (convolutional downsampling); when a prompt is missing, a length-1 embedding (not_a_point_embed / no_mask_embed) is used in its place.

The encodings of the three prompt types are shown below; they are used inside forward() (introduced in the ==🚠 Prompt Encoding== part of the previous section).

    def _embed_points(
        self,
        points: torch.Tensor,
        labels: torch.Tensor,
        pad: bool,
    ) -> torch.Tensor:
        """Embeds point prompts."""
        points = points + 0.5  # Shift to center of pixel
        if pad:
            padding_point = torch.zeros((points.shape[0], 1, 2), device=points.device)
            padding_label = -torch.ones((labels.shape[0], 1), device=labels.device)
            points = torch.cat([points, padding_point], dim=1)
            labels = torch.cat([labels, padding_label], dim=1)
        point_embedding = self.pe_layer.forward_with_coords(
            points, self.input_image_size
        )
        point_embedding[labels == -1] = 0.0
        point_embedding[labels == -1] += self.not_a_point_embed.weight
        point_embedding[labels == 0] += self.point_embeddings[0].weight
        point_embedding[labels == 1] += self.point_embeddings[1].weight
        point_embedding[labels == 2] += self.point_embeddings[2].weight
        point_embedding[labels == 3] += self.point_embeddings[3].weight
        return point_embedding

    def _embed_boxes(self, boxes: torch.Tensor) -> torch.Tensor:
        """Embeds box prompts."""
        boxes = boxes + 0.5  # Shift to center of pixel
        coords = boxes.reshape(-1, 2, 2)
        corner_embedding = self.pe_layer.forward_with_coords(
            coords, self.input_image_size
        )
        corner_embedding[:, 0, :] += self.point_embeddings[2].weight
        corner_embedding[:, 1, :] += self.point_embeddings[3].weight
        return corner_embedding

    def _embed_masks(self, masks: torch.Tensor) -> torch.Tensor:
        """Embeds mask inputs."""
        mask_embedding = self.mask_downscaling(masks)
        return mask_embedding

  • Point encoding: the pos/neg click points are positionally encoded and then have embedding layers 0 and 1 added respectively; when there is no box prompt, a zero-coordinate padding point with label -1 (mapped to not_a_point_embed) is appended, keeping the sparse token layout consistent in the forward pass;

  • Box encoding: the two corner points are positionally encoded directly, then embedding layers 2 and 3 are added respectively;

  • Mask encoding: the mask is simply fed through the downsampling convolutional network.

==The prompt encoder's structure is very simple and has few parameters; when no prompt is given, the whole encoder contributes almost nothing.==
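
A small hedged sketch of calling the prompt encoder directly with one box (the coordinates must already be in the model's input resolution, e.g. 1024x1024; going through the predictor, _prep_prompts handles that transform, and the shapes below assume embed_dim=256 with a 64x64 image embedding grid):

import torch

box = torch.tensor([[100.0, 150.0, 600.0, 700.0]], device=predictor.device)   # (1, 4) in xyxy
sparse, dense = predictor.model.sam_prompt_encoder(points=None, boxes=box, masks=None)
print(sparse.shape)   # (1, 2, 256): the two corner embeddings
print(dense.shape)    # (1, 256, 64, 64): no_mask_embed broadcast over the grid, since no mask was given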

🚠 Mask Decoder: mask_decoder.py

class MaskDecoder(nn.Module):
    def __init__(
        self,
        *,
        transformer_dim: int,
        transformer: nn.Module,
        num_multimask_outputs: int = 3,
        activation: Type[nn.Module] = nn.GELU,
        iou_head_depth: int = 3,
        iou_head_hidden_dim: int = 256,
        use_high_res_features: bool = False,
        iou_prediction_use_sigmoid=False,
        dynamic_multimask_via_stability=False,
        dynamic_multimask_stability_delta=0.05,
        dynamic_multimask_stability_thresh=0.98,
        pred_obj_scores: bool = False,
        pred_obj_scores_mlp: bool = False,
        use_multimask_token_for_obj_ptr: bool = False,
    ) -> None:
        """
        Predicts masks given an image and prompt embeddings, using a
        transformer architecture.

        Arguments:
          transformer_dim (int): the channel dimension of the transformer
          transformer (nn.Module): the transformer used to predict masks
          num_multimask_outputs (int): the number of masks to predict
            when disambiguating masks
          activation (nn.Module): the type of activation to use when
            upscaling masks
          iou_head_depth (int): the depth of the MLP used to predict
            mask quality
          iou_head_hidden_dim (int): the hidden dimension of the MLP
            used to predict mask quality
        """
        super().__init__()
        self.transformer_dim = transformer_dim
        self.transformer = transformer

        self.num_multimask_outputs = num_multimask_outputs

        self.iou_token = nn.Embedding(1, transformer_dim)
        self.num_mask_tokens = num_multimask_outputs + 1
        self.mask_tokens = nn.Embedding(self.num_mask_tokens, transformer_dim)

        self.pred_obj_scores = pred_obj_scores
        if self.pred_obj_scores:
            self.obj_score_token = nn.Embedding(1, transformer_dim)
        self.use_multimask_token_for_obj_ptr = use_multimask_token_for_obj_ptr

        self.output_upscaling = nn.Sequential(
            nn.ConvTranspose2d(
                transformer_dim, transformer_dim // 4, kernel_size=2, stride=2
            ),
            LayerNorm2d(transformer_dim // 4),
            activation(),
            nn.ConvTranspose2d(
                transformer_dim // 4, transformer_dim // 8, kernel_size=2, stride=2
            ),
            activation(),
        )
        self.use_high_res_features = use_high_res_features
        if use_high_res_features:
            self.conv_s0 = nn.Conv2d(
                transformer_dim, transformer_dim // 8, kernel_size=1, stride=1
            )
            self.conv_s1 = nn.Conv2d(
                transformer_dim, transformer_dim // 4, kernel_size=1, stride=1
            )

        self.output_hypernetworks_mlps = nn.ModuleList(
            [
                MLP(transformer_dim, transformer_dim, transformer_dim // 8, 3)
                for i in range(self.num_mask_tokens)
            ]
        )

        self.iou_prediction_head = MLP(
            transformer_dim,
            iou_head_hidden_dim,
            self.num_mask_tokens,
            iou_head_depth,
            sigmoid_output=iou_prediction_use_sigmoid,
        )
        if self.pred_obj_scores:
            self.pred_obj_score_head = nn.Linear(transformer_dim, 1)
            if pred_obj_scores_mlp:
                self.pred_obj_score_head = MLP(transformer_dim, transformer_dim, 1, 3)

        # When outputting a single mask, optionally we can dynamically fall back to the best
        # multimask output token if the single mask output token gives low stability scores.
        self.dynamic_multimask_via_stability = dynamic_multimask_via_stability
        self.dynamic_multimask_stability_delta = dynamic_multimask_stability_delta
        self.dynamic_multimask_stability_thresh = dynamic_multimask_stability_thresh

The core of the decoder is a transformer. Around it are the learned tokens for IoU and mask prediction (self.iou_token and self.mask_tokens), the transposed-convolution network self.output_upscaling that upscales the masks, the list of MLPs self.output_hypernetworks_mlps that generate the multiple masks, the MLP self.iou_prediction_head that predicts the IoU scores, and the head self.pred_obj_score_head that predicts the occlusion (object) score; none of these latter parts is particularly special.

The forward pass through the network was covered in detail in the ==🚠 Mask Decoding== part of the previous section; here we add a note on masks, iou_pred = self._dynamic_multimask_via_stability(masks, iou_pred), which forward() calls after the network output:

    def _get_stability_scores(self, mask_logits):
        """
        Compute stability scores of the mask logits based on the IoU between upper and
        lower thresholds, similar to https://github.com/fairinternal/onevision/pull/568.
        """
        mask_logits = mask_logits.flatten(-2)
        stability_delta = self.dynamic_multimask_stability_delta
        area_i = torch.sum(mask_logits > stability_delta, dim=-1).float()
        area_u = torch.sum(mask_logits > -stability_delta, dim=-1).float()
        stability_scores = torch.where(area_u > 0, area_i / area_u, 1.0)
        return stability_scores

    def _dynamic_multimask_via_stability(self, all_mask_logits, all_iou_scores):
        """
        When outputting a single mask, if the stability score from the current single-mask
        output (based on output token 0) falls below a threshold, we instead select from
        multi-mask outputs (based on output token 1~3) the mask with the highest predicted
        IoU score. This is intended to ensure a valid mask for both clicking and tracking.
        """
        # The best mask from multimask output tokens (1~3)
        multimask_logits = all_mask_logits[:, 1:, :, :]
        multimask_iou_scores = all_iou_scores[:, 1:]
        best_scores_inds = torch.argmax(multimask_iou_scores, dim=-1)
        batch_inds = torch.arange(
            multimask_iou_scores.size(0), device=all_iou_scores.device
        )
        best_multimask_logits = multimask_logits[batch_inds, best_scores_inds]
        best_multimask_logits = best_multimask_logits.unsqueeze(1)
        best_multimask_iou_scores = multimask_iou_scores[batch_inds, best_scores_inds]
        best_multimask_iou_scores = best_multimask_iou_scores.unsqueeze(1)

        # The mask from singlemask output token 0 and its stability score
        singlemask_logits = all_mask_logits[:, 0:1, :, :]
        singlemask_iou_scores = all_iou_scores[:, 0:1]
        stability_scores = self._get_stability_scores(singlemask_logits)
        is_stable = stability_scores >= self.dynamic_multimask_stability_thresh

        # Dynamically fall back to best multimask output upon low stability scores.
        mask_logits_out = torch.where(
            is_stable[..., None, None].expand_as(singlemask_logits),
            singlemask_logits,
            best_multimask_logits,
        )
        iou_scores_out = torch.where(
            is_stable.expand_as(singlemask_iou_scores),
            singlemask_iou_scores,
            best_multimask_iou_scores,
        )
        return mask_logits_out, iou_scores_out

The purpose of this function is: when only a single mask is output, if the stability score of the mask from output token 0 falls below a threshold, it instead selects, from the multi-mask outputs (based on output tokens 1~3), the mask with the highest predicted IoU score. The stability score stability_scores is area_i / area_u, i.e. an IoU-like ratio between the masks obtained by thresholding the logits at +delta and at -delta (falling back to 1.0 when area_u is 0). So this function is simply a robustness trick applied to the generated mask.
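
A tiny worked example of the stability score on a 2x2 logit map (delta = 0.05, the default):

import torch

mask_logits = torch.tensor([[[0.20, 0.04], [-0.02, -0.30]]])    # (1, 2, 2) toy logits
delta = 0.05
flat = mask_logits.flatten(-2)
area_i = (flat > delta).sum(-1).float()      # 1 pixel above +delta
area_u = (flat > -delta).sum(-1).float()     # 3 pixels above -delta
print((area_i / area_u).item())              # 0.333..., well below 0.98, so this mask counts as unstable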

Two-way Transformer: transformer.py

The mask decoder's two-way transformer

This is the transformer module used by the mask decoder; "TwoWay" presumably refers to the two inputs forming two propagation paths. It contains:

  • class Attention(nn.Module): ++every attention mechanism++ inside the mask decoder is this module, a dot-product multi-head attention layer with added linear projections that can change the dimensions of the input q, k, v before the attention computation;

  • class TwoWayAttentionBlock(nn.Module): the small box in the figure above, containing one self-attention, two cross-attentions, an MLP, and LayerNorm layers.

  • class TwoWayTransformer(nn.Module): the large box in the figure above, containing two TwoWayAttentionBlock layers and a final output attention layer self.final_attn_token_to_image.

The code is below. The forward pass does follow the propagation paths in the figure: in TwoWayAttentionBlock.forward(), queries and keys are respectively the (sparse) prompt embedding and the (dense) image embedding with the mask prompt added. Note that the two cross-attentions swap q and k, matching "token to image" and "image to token" in the figure; also, the queries the block finally returns are taken after the MLP, while the returned keys are taken after the last cross-attention, matching the two exit arrows at different positions in the figure. If the block is the last one, these two outputs go to self.final_attn_token_to_image and the upscaling module; otherwise they become the q and k inputs of the next block.

class TwoWayTransformer(nn.Module):
    def __init__(
        self,
        depth: int,
        embedding_dim: int,
        num_heads: int,
        mlp_dim: int,
        activation: Type[nn.Module] = nn.ReLU,
        attention_downsample_rate: int = 2,
    ) -> None:
        """
        A transformer decoder that attends to an input image using
        queries whose positional embedding is supplied.

        Args:
          depth (int): number of layers in the transformer
          embedding_dim (int): the channel dimension for the input embeddings
          num_heads (int): the number of heads for multihead attention. Must
            divide embedding_dim
          mlp_dim (int): the channel dimension internal to the MLP block
          activation (nn.Module): the activation to use in the MLP block
        """
        super().__init__()
        self.depth = depth
        self.embedding_dim = embedding_dim
        self.num_heads = num_heads
        self.mlp_dim = mlp_dim
        self.layers = nn.ModuleList()

        for i in range(depth):
            self.layers.append(
                TwoWayAttentionBlock(
                    embedding_dim=embedding_dim,
                    num_heads=num_heads,
                    mlp_dim=mlp_dim,
                    activation=activation,
                    attention_downsample_rate=attention_downsample_rate,
                    skip_first_layer_pe=(i == 0),
                )
            )

        self.final_attn_token_to_image = Attention(
            embedding_dim, num_heads, downsample_rate=attention_downsample_rate
        )
        self.norm_final_attn = nn.LayerNorm(embedding_dim)

    def forward(
        self,
        image_embedding: Tensor,
        image_pe: Tensor,
        point_embedding: Tensor,
    ) -> Tuple[Tensor, Tensor]:
        """
        Args:
          image_embedding (torch.Tensor): image to attend to. Should be shape
            B x embedding_dim x h x w for any h and w.
          image_pe (torch.Tensor): the positional encoding to add to the image. Must
            have the same shape as image_embedding.
          point_embedding (torch.Tensor): the embedding to add to the query points.
            Must have shape B x N_points x embedding_dim for any N_points.

        Returns:
          torch.Tensor: the processed point_embedding
          torch.Tensor: the processed image_embedding
        """
        # BxCxHxW -> BxHWxC == B x N_image_tokens x C
        bs, c, h, w = image_embedding.shape
        image_embedding = image_embedding.flatten(2).permute(0, 2, 1)
        image_pe = image_pe.flatten(2).permute(0, 2, 1)

        # Prepare queries
        queries = point_embedding
        keys = image_embedding

        # Apply transformer blocks and final layernorm
        for layer in self.layers:
            queries, keys = layer(
                queries=queries,
                keys=keys,
                query_pe=point_embedding,
                key_pe=image_pe,
            )

        # Apply the final attention layer from the points to the image
        q = queries + point_embedding
        k = keys + image_pe
        attn_out = self.final_attn_token_to_image(q=q, k=k, v=keys)
        queries = queries + attn_out
        queries = self.norm_final_attn(queries)

        return queries, keys

class TwoWayAttentionBlock(nn.Module):
    def __init__(
        self,
        embedding_dim: int,
        num_heads: int,
        mlp_dim: int = 2048,
        activation: Type[nn.Module] = nn.ReLU,
        attention_downsample_rate: int = 2,
        skip_first_layer_pe: bool = False,
    ) -> None:
        """
        A transformer block with four layers: (1) self-attention of sparse
        inputs, (2) cross attention of sparse inputs to dense inputs, (3) mlp
        block on sparse inputs, and (4) cross attention of dense inputs to sparse
        inputs.

        Arguments:
          embedding_dim (int): the channel dimension of the embeddings
          num_heads (int): the number of heads in the attention layers
          mlp_dim (int): the hidden dimension of the mlp block
          activation (nn.Module): the activation of the mlp block
          skip_first_layer_pe (bool): skip the PE on the first layer
        """
        super().__init__()
        self.self_attn = Attention(embedding_dim, num_heads)
        self.norm1 = nn.LayerNorm(embedding_dim)

        self.cross_attn_token_to_image = Attention(
            embedding_dim, num_heads, downsample_rate=attention_downsample_rate
        )
        self.norm2 = nn.LayerNorm(embedding_dim)

        self.mlp = MLP(
            embedding_dim, mlp_dim, embedding_dim, num_layers=2, activation=activation
        )
        self.norm3 = nn.LayerNorm(embedding_dim)

        self.norm4 = nn.LayerNorm(embedding_dim)
        self.cross_attn_image_to_token = Attention(
            embedding_dim, num_heads, downsample_rate=attention_downsample_rate
        )

        self.skip_first_layer_pe = skip_first_layer_pe

    def forward(
        self, queries: Tensor, keys: Tensor, query_pe: Tensor, key_pe: Tensor
    ) -> Tuple[Tensor, Tensor]:
        # Self attention block
        if self.skip_first_layer_pe:
            queries = self.self_attn(q=queries, k=queries, v=queries)
        else:
            q = queries + query_pe
            attn_out = self.self_attn(q=q, k=q, v=queries)
            queries = queries + attn_out
        queries = self.norm1(queries)

        # Cross attention block, tokens attending to image embedding
        q = queries + query_pe
        k = keys + key_pe
        attn_out = self.cross_attn_token_to_image(q=q, k=k, v=keys)
        queries = queries + attn_out
        queries = self.norm2(queries)

        # MLP block
        mlp_out = self.mlp(queries)
        queries = queries + mlp_out
        queries = self.norm3(queries)

        # Cross attention block, image embedding attending to tokens
        q = queries + query_pe
        k = keys + key_pe
        attn_out = self.cross_attn_image_to_token(q=k, k=q, v=queries)
        keys = keys + attn_out
        keys = self.norm4(keys)

        return queries, keys

==The mask decoder's structure is relatively complex, but the figure matches the code implementation exactly, so understanding the figure is enough.==

🔥 Architecture, Illustrated

==So far we have seen, through the training procedure, how SAM2 is built and how its forward pass works at a high level, and then, through the walkthrough of the sam2 folder, the finer network structure, forward-pass details, and some technical highlights.==

Below is a summary diagram of SAM2 that I drew based on the source code and the paper (the top-right corner is the simplified version); it shows SAM2's structure and working principle fairly accurately:

Detailed SAM2 architecture

Data flow diagram

The figure above is a prediction example; the gradient-colored boxes show the shape of the data at each key node. Two images are fed in, each with 4 target boxes as prompts; the prediction result is shown in the figure further below, which amounts to instance segmentation. The notes below explain how some key parts reflect the model characteristics introduced earlier:

  • Image encoder: the 3-channel input is patchified down to a 256 grid; the trunk of the image encoder produces features at four different sizes over its four stages (256, 128, 64, 32), and the four feature maps are passed one-to-one to the neck's convolutions, whose output channels are unified to 256. The high-resolution high_res_feats keep the 128 and 256 sizes, while the 64 size becomes the image embedding image_embed.

  • Prompt encoder: its output channel count is also 256; the sparse and dense prompt embeddings obviously differ in size. Before entering the decoder, both the image and prompt embeddings are reshaped (the 4 in the shapes denotes the 4 prompt boxes);

  • Entering the decoder, output_tokens consists of 6 tokens in total: iou_token (1), mask_tokens (4, i.e. 1 + 3 for the multi-mask output), and obj_token (1, unused for the image task and therefore not drawn). Inside the decoder it is easy to see that the purple lines are all dense features and the green lines are sparse features.

  • On the output side:

    1. From the output of the final cross-attention layer, the iou_token, mask_tokens, and obj_token are read back out; each of the mask_tokens goes through its own MLP to predict a mask, hence the 2×4×4×32 output;

    2. The IoU scores have shape (2→batch, 4→number of boxes, 4→multi-mask outputs) because every mask gets its own IoU prediction;

    3. The occlusion score is (2→batch, 4→number of boxes, 1) because each object only needs one judgment of whether it is occluded (unused for the image task);

    4. At the matrix-multiplication step, masks of shape (2→batch, 4→number of boxes, 4→multi-mask outputs, 256→height, 256→width) are produced.

  • Outside the decoder, if the mask from token 0 for a given target has a poor stability score, the mask with the highest IoU among the remaining 3 is output instead (the self._dynamic_multimask_via_stability mentioned earlier), leaving (2, 4, 1, 256, 256); naturally only one IoU score per target remains as well, and finally the masks are restored to the original image size.

  • Note that at prediction time different images can have different numbers of prompts, and therefore different numbers of masks to predict, so the prompts cannot be packed into a single tensor and are stored as lists instead. This means images are fed one by one into the prompt encoder and mask decoder, and the effective batch size becomes the number of targets to predict in that image (the 2 inside the decoder diagram is only illustrative); see the sketch below.
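
A hedged sketch of batched prediction with a different number of box prompts per image (the box values are made up; the argument names follow the predict_batch signature in the repo, so verify them against your version):

import numpy as np

boxes_per_image = [
    np.array([[10, 20, 200, 220], [300, 40, 500, 400]]),   # 2 boxes for the first image
    np.array([[50, 60, 180, 300]]),                        # 1 box for the second image
]
predictor.set_image_batch(image_list)                      # list of HWC RGB np.ndarray images
masks_list, scores_list, _ = predictor.predict_batch(
    box_batch=boxes_per_image,
    multimask_output=False,
)
for m in masks_list:
    print(m.shape)    # one entry per image: the masks for that image's boxes at the original resolution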

Prediction example

Below, a brief summary of what each module does:

Summary of SAM2 module functions

From the project structure, the role of each script under sam2/modeling is summarized as follows:

Summary of SAM2 file functions

🔥 Strengths and Weaknesses

SAM2's strong results rest on the following factors:

  1. For raw performance, the most important factor is the use of attention, which with large parameter counts and large-scale datasets is much stronger than traditional convolution;

  2. Multi-scale features: both the image feature extraction and the final upscaling use image features at several scales, an old trick from YOLO-style models that still works;

  3. Prompting, the most innovative aspect compared with traditional architectures: it lets the user convey their intent in different forms. Compared with pre-defining classes and then training, this approach directly satisfies concrete needs in a broad sense (abstract ones it cannot satisfy). SAM2 also encodes different prompt types differently, and dense and sparse prompts are used in different ways;

  4. The image encoder presumably uses pretrained weights and is already quite capable; on top of that, SAM2 built a 1B-scale dataset, and the model itself can keep generating new data, which is very practical;

  5. The mask decoder's design fuses the prompt and image modalities well, and its decoupled outputs, like the YOLO family's, can produce the different kinds of predictions independently;

  6. It is lightweight: although not as fast as YOLO, the model is still not large. The multi-scale design and the many downsampling operations effectively reduce parameters and computation, striking a good balance between quality and speed, and models at several scales are provided.

Overall, SAM2 incorporates many good techniques seen in ViT, YOLO, and other works, innovates on the modality side itself, and together with the strength of transformers achieves very good results.

Weaknesses and possible improvements:

  1. The SAM1 paper included text among the prompts, while SAM2's diagrams drop text as a prompt; neither project's source code shows any sign that text can be fed into the model as a prompt. The SAM1 paper attached an external CLIP text encoder and fed the encoded text into SAM as a prompt; it worked for simple words but did best when combined with point prompts. If text could achieve better predictions, it would actually surpass the other prompt forms: point, box, and mask inputs require the user to already know where things are and require an application scenario that allows that interaction, so that information has a "cheating" flavor, whereas text needs no interaction, carries no "cheating", and expresses the need precisely. That would bring SAM partway toward what LISA achieves;

  2. For segmentation, SAM2 itself has no classification ability: it does not care about categories, only about what things look like. Given a prompt, it segments a mask for it; given no prompt, the model only outputs a mask for a single object. So if classification is needed, or prompts should be generated automatically, the SAM1 paper's approach was to first use an object detector (ViTDet) to generate boxes as prompts. This works, and attaching a small external model is not a big cost, but it also reflects a limitation of SAM to some extent;

  3. Whether the architecture can be simplified;

  4. Whether a newer positional encoding could be used; rotary position embedding is implemented in the code but does not appear to be invoked here;

Overall, SAM2's position in image segmentation is more like that of a foundation model driven by prompt input. It should not be treated as an all-purpose segmentation model: outside its core promptable segmentation, the more appropriate way to use it is as a segmenter module, the way LISA does, and its vision encoder alone is also valuable (much like a ResNet-50 backbone), because it has learned from a huge amount of image data.

So in my view, there are two paths to improving SAM: push its quality further within promptable segmentation, or treat it as one component of a larger system and build more useful functionality around it.
