Fine-tuning the SAM2 Image Segmentation Model

Segment Anything Model 2 (SAM 2), Meta's upgrade to the original Segment Anything Model (SAM), brings notable strengths to both image and video segmentation.

This post uses SAM2 for prediction, training, validation, and evaluation on an image segmentation task. Reference repositories:

Fine-tuning code reference

SAM2 source code

🔥Preparation

Assuming the GPU driver and CUDA (12.6 in my case) are already installed, install conda by downloading it from the link below:

conda

Install conda, set up the environment, and download the checkpoint files along the way:

bash Anaconda3-2024.10-1-Linux-x86_64.sh
conda create -n leosam2 python=3.10
conda activate leosam2
pip install -e .        # run inside the cloned SAM2 repository
cd checkpoints
./download_ckpts.sh

Add the project-specific VS Code configuration (settings.json and launch.json):

{
    "python.autoComplete.extraPaths": [
        "/home/zigaa/leosam2/sam2"
    ],
    "python.analysis.extraPaths": [
        "/home/zigaa/leosam2/sam2"
    ]
}
{
    "version": "0.2.0",
    "configurations": [
        {
            "name": "Python 调试程序: 预测",
            "type": "debugpy",
            "request": "launch",
            "program": "/home/zigaa/leosam2/training/predict.py",
            "console": "integratedTerminal"
        }
    ]
}

🔥Dataset

LabPicsV1 dataset

We start with a relatively small dataset, LabPicsV1, to get the code running and to try out tuning strategies. The task is to segment vessels and the materials (solids and liquids) inside them; it is reasonably challenging yet compact, which makes it well suited for initial experiments.

The data consists of images and their masks, both stored as image files. The masks provide instance-level annotations: of the three channels, channel 0 marks materials and channel 2 marks vessels, so the two have to be merged before being fed to the model.
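A minimal sketch of that merge (the `ann_path` below is a hypothetical path to one annotation PNG; the Dataset class in the training section later does exactly the same thing):

import cv2
import numpy as np

ann_path = "LabPicsV1/Simple/Train/Instance/example.png"   # hypothetical annotation file
ann = cv2.imread(ann_path)             # annotation stored as a 3-channel image
mat_map = ann[:, :, 0]                 # channel 0: material instances
ves_map = ann[:, :, 2]                 # channel 2: vessel instances
# where no material is labeled, fill in the vessel labels (offset past the material ids)
mat_map[mat_map == 0] = ves_map[mat_map == 0] * (mat_map.max() + 1)
binary_mask = (mat_map > 0).astype(np.uint8)   # one mask covering all labeled targets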

🔥Prediction

🚀Basic prediction

Load the checkpoint and an image, run segmentation, and plot the result.

import os
import numpy as np
import torch
import matplotlib.pyplot as plt
from PIL import Image
from sam2.build_sam import build_sam2
from sam2.sam2_image_predictor import SAM2ImagePredictor
import sys

os.chdir(sys.path[0])
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")
print(f"using device: {device}")

np.random.seed(3)

def show_mask(mask, ax, random_color=False, borders = True):
    if random_color:
        color = np.concatenate([np.random.random(3), np.array([0.6])], axis=0)
    else:
        color = np.array([30/255, 144/255, 255/255, 0.6])
    h, w = mask.shape[-2:]
    mask = mask.astype(np.uint8)
    mask_image =  mask.reshape(h, w, 1) * color.reshape(1, 1, -1)
    if borders:
        import cv2
        contours, _ = cv2.findContours(mask,cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_NONE) 
        # Try to smooth contours
        contours = [cv2.approxPolyDP(contour, epsilon=0.01, closed=True) for contour in contours]
        mask_image = cv2.drawContours(mask_image, contours, -1, (1, 1, 1, 0.5), thickness=2) 
    ax.imshow(mask_image)

def show_points(coords, labels, ax, marker_size=375):
    pos_points = coords[labels==1]
    neg_points = coords[labels==0]
    ax.scatter(pos_points[:, 0], pos_points[:, 1], color='green', marker='*', s=marker_size, edgecolor='white', linewidth=1.25)
    ax.scatter(neg_points[:, 0], neg_points[:, 1], color='red', marker='*', s=marker_size, edgecolor='white', linewidth=1.25)   

def show_box(box, ax):
    x0, y0 = box[0], box[1]
    w, h = box[2] - box[0], box[3] - box[1]
    ax.add_patch(plt.Rectangle((x0, y0), w, h, edgecolor='green', facecolor=(0, 0, 0, 0), lw=2))    

def show_masks(image, masks, scores, point_coords=None, box_coords=None, input_labels=None, borders=True):
    for i, (mask, score) in enumerate(zip(masks, scores)):
        plt.figure(figsize=(10, 10))
        plt.imshow(image)
        show_mask(mask, plt.gca(), borders=borders)
        if point_coords is not None:
            assert input_labels is not None
            show_points(point_coords, input_labels, plt.gca())
        if box_coords is not None:
            # boxes
            show_box(box_coords, plt.gca())
        if len(scores) > 1:
            plt.title(f"Mask {i+1}, Score: {score:.3f}", fontsize=18)
        plt.axis('off')
        plt.show()

image = Image.open('../notebooks/images/truck.jpg')
image = np.array(image.convert("RGB"))

sam2_checkpoint = "../checkpoints/sam2_hiera_large.pt"
model_cfg = "../sam2_configs/sam2_hiera_l.yaml"

sam2_model = build_sam2(model_cfg, sam2_checkpoint, device=device)

predictor = SAM2ImagePredictor(sam2_model)

predictor.set_image(image)

input_point = np.array([[500, 375]])
input_label = np.array([1])
masks, scores, logits = predictor.predict(
    point_coords=input_point,
    point_labels=input_label,
    multimask_output=True,
)
sorted_ind = np.argsort(scores)[::-1]
masks = masks[sorted_ind]
scores = scores[sorted_ind]
logits = logits[sorted_ind]

show_masks(image, masks, scores, point_coords=input_point, input_labels=input_label, borders=True)

Prediction initially raised an error: before the scaled dot-product attention, the source code enables some kernel-selection options that force q/k/v to be in half precision, so those settings are commented out here for now.

        # Attention
        # with torch.backends.cuda.sdp_kernel(
        #     enable_flash=USE_FLASH_ATTN,
        #     # if Flash attention kernel is off, then math kernel needs to be enabled
        #     enable_math=(OLD_GPU and dropout_p > 0.0) or MATH_KERNEL_ON,
        #     enable_mem_efficient=OLD_GPU,
        # ):
        out = F.scaled_dot_product_attention(q, k, v, dropout_p=dropout_p)

Basic prediction

This is the result with a single target point given as the prompt; the model segments the target reliably.

🚀Segment everything

from sam2.automatic_mask_generator import SAM2AutomaticMaskGenerator

sam2 = build_sam2(model_cfg, sam2_checkpoint, device=device, apply_postprocessing=False)

# mask_generator = SAM2AutomaticMaskGenerator(sam2)
mask_generator = SAM2AutomaticMaskGenerator(
    model=sam2,
    points_per_side=64,
    points_per_batch=128,
    pred_iou_thresh=0.7,
    stability_score_thresh=0.92,
    stability_score_offset=0.7,
    crop_n_layers=0,
    box_nms_thresh=0.7,
    crop_n_points_downscale_factor=2,
    min_mask_region_area=25.0,
    use_m2m=True,
)
masks = mask_generator.generate(image)
print(f'{len(masks)} masks')
print(masks[0].keys())

plt.figure(figsize=(20, 20))
plt.imshow(image)
show_anns(masks)  # show_anns is the mask-overlay helper from the official SAM2 example notebook
plt.axis('off')
plt.show() 

The whole image can be segmented with the approach above; the result looks like this:

Segment everything

🔥Training (reference)

This uses the code from the fine-tuning reference repository. It is a fairly simple training script for the LabPicsV1 dataset, mainly segmenting vessels and the liquids inside them.

Script flow: load and preprocess the dataset -> build the model and load the weights -> define the optimizer -> training loop.

Training is essentially running forward passes through SAM2's individual components.

Training

🔥Training

Next we write our own training script. Only a single mask is produced per sample; segmenting all the vessels in the image is enough.

🚀dataset/dataloader

import os
import cv2
import numpy as np
from torch.utils.data import Dataset

class LabPicsDataset(Dataset):
    def __init__(self, data_dir, transform=None):
        self.data_dir = data_dir
        self.transform = transform
        self.data = self._load_data()

    def _load_data(self):
        data = []
        # data_dir already points at one split, e.g. .../Simple/Train/ or .../Simple/Test
        image_folder = os.path.join(self.data_dir, "Image")
        annotation_folder = os.path.join(self.data_dir, "Instance")
        for name in os.listdir(image_folder):
            image_path = os.path.join(image_folder, name)
            annotation_path = os.path.join(annotation_folder, name[:-4] + ".png")
            data.append({"image": image_path, "annotation": annotation_path})
        return data

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        entry = self.data[idx]
        image = cv2.imread(entry["image"])[..., ::-1]  # 读取图像并转换为 RGB 格式
        annotation = cv2.imread(entry["annotation"])  # 读取标注

        # 调整图像大小
        r = np.min([1024 / image.shape[1], 1024 / image.shape[0]])  # 缩放因子
        image = cv2.resize(image, (int(image.shape[1] * r), int(image.shape[0] * r)))
        annotation = cv2.resize(annotation, (int(annotation.shape[1] * r), int(annotation.shape[0] * r)), interpolation=cv2.INTER_NEAREST)

        if image.shape[0] < 1024:
            image = np.concatenate([image, np.zeros([1024 - image.shape[0], image.shape[1], 3], dtype=np.uint8)], axis=0)
            annotation = np.concatenate([annotation, np.zeros([1024 - annotation.shape[0], annotation.shape[1], 3], dtype=np.uint8)], axis=0)
        if image.shape[1] < 1024:
            image = np.concatenate([image, np.zeros([image.shape[0], 1024 - image.shape[1], 3], dtype=np.uint8)], axis=1)
            annotation = np.concatenate([annotation, np.zeros([annotation.shape[0], 1024 - annotation.shape[1], 3], dtype=np.uint8)], axis=1)

        # merge the material and vessel annotations
        mat_map = annotation[:, :, 0]
        ves_map = annotation[:, :, 2]
        mat_map[mat_map == 0] = ves_map[mat_map == 0] * (mat_map.max() + 1)

        # build the binary mask and a prompt point
        inds = np.unique(mat_map)[1:]
        if len(inds) > 0:
            ind = inds[np.random.randint(len(inds))]
            # mask = (mat_map == ind).astype(np.uint8)   # per-instance mask (unused here)
            mask = (mat_map > 0).astype(np.uint8)           # merge all masks into one (predict everything at once)
            coords = np.argwhere(mask > 0)
            yx = coords[np.random.randint(len(coords))]
            point = [[yx[1], yx[0]]]
        else:
            # no valid annotation: return an all-zero mask and a random point
            mask = np.zeros((image.shape[:2]), dtype=np.uint8)
            point = [[np.random.randint(0, 1024), np.random.randint(0, 1024)]]

        if self.transform:
            image = self.transform(image)

        return image, mask, np.array(point, dtype=np.float32), np.ones([1])

This Dataset class is tailored to the dataset: it reads the annotation file and merges the vessel and material channels into the labeled mask. A single sample contains the input image (3*H*W), the ground-truth mask (H*W), a prompt point (list of coordinates), and the prompt type (positive/negative; only positive is used here for now).

Here mask = (mat_map > 0).astype(np.uint8) merges all masks in the image into one, because for now we want the model to find every target in one shot (instance-level segmentation is not really needed and would also make inference slow), whereas randomly picking a single instance mask each time would make the model's predictions inconsistent.
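A quick sanity check of one sample (a sketch, assuming the dataset sits under ../assets/LabPicsV1/ as in the training script below):

import torchvision.transforms as transforms

ds = LabPicsDataset("../assets/LabPicsV1/Simple/Train/", transform=transforms.ToTensor())
image, mask, point, label = ds[0]
print(image.shape)    # torch.Size([3, 1024, 1024]) after resizing and padding
print(mask.shape)     # (1024, 1024), all targets merged into one binary mask
print(point, label)   # one positive prompt point, e.g. [[x, y]] with label [1.]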

🚀Loss functions

Dice Loss is commonly used for segmentation, especially for binary mask prediction. It behaves like a generalized intersection-over-union (IoU) and measures the similarity between the predicted and ground-truth masks.
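In symbols, with predicted probabilities $p_i$ (after the sigmoid) and binary targets $g_i$, the smoothed per-mask Dice loss computed below is

$$\mathcal{L}_{\text{Dice}} = 1 - \frac{2\sum_i p_i g_i + 1}{\sum_i p_i + \sum_i g_i + 1},$$

averaged over the number of objects.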

import torch
import torch.nn.functional as F
from typing import Dict, Tuple

def dice_loss(inputs, targets, num_objects, loss_on_multimask=False):
    """
    Dice Loss
    """
    inputs = inputs.sigmoid()
    if loss_on_multimask:
        # inputs and targets are [N, M, H, W] where M corresponds to multiple predicted masks
        assert inputs.dim() == 4 and targets.dim() == 4
        # flatten spatial dimension while keeping multimask channel dimension
        inputs = inputs.flatten(2)
        targets = targets.flatten(2)
        numerator = 2 * (inputs * targets).sum(-1)
    else:
        inputs = inputs.flatten(1)
        targets = targets.flatten(1)
        numerator = 2 * (inputs * targets).sum(1)
    denominator = inputs.sum(-1) + targets.sum(-1)
    loss = 1 - (numerator + 1) / (denominator + 1)
    if loss_on_multimask:
        return loss / num_objects
    return loss.sum() / num_objects

Focal Loss improves on cross entropy and is well suited to dense prediction.
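In symbols, with $p_t = p\,g + (1-p)(1-g)$ and $\alpha_t = \alpha g + (1-\alpha)(1-g)$, the per-pixel loss implemented below is

$$\mathcal{L}_{\text{focal}} = -\alpha_t\,(1 - p_t)^{\gamma}\,\log(p_t),$$

which down-weights easy, well-classified pixels.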

def sigmoid_focal_loss(
    inputs,
    targets,
    num_objects,
    alpha: float = 0.25,
    gamma: float = 2,
    loss_on_multimask=False,
):
    """
    Focal Loss 
    """
    prob = inputs.sigmoid()
    ce_loss = F.binary_cross_entropy_with_logits(inputs, targets, reduction="none")
    p_t = prob * targets + (1 - prob) * (1 - targets)
    loss = ce_loss * ((1 - p_t) ** gamma)

    if alpha >= 0:
        alpha_t = alpha * targets + (1 - alpha) * (1 - targets)
        loss = alpha_t * loss

    if loss_on_multimask:
        # loss is [N, M, H, W] where M corresponds to multiple predicted masks
        assert loss.dim() == 4
        return loss.flatten(2).mean(-1) / num_objects  # average over spatial dims
    return loss.mean((1, 2, 3)).sum() / num_objects

dIoU Loss, where "d" stands for difference: the model predicts an IoU score, and this loss measures the error of that prediction (IoU is the intersection-over-union ratio, reflecting how well the predicted mask overlaps the ground truth).
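In other words, with predicted mask $P$, ground-truth mask $G$, and the model's predicted score $\hat{u}$, the loss below computes

$$\text{IoU}(P,G) = \frac{|P \cap G|}{|P \cup G|}, \qquad \mathcal{L}_{\text{dIoU}} = \big(\hat{u} - \text{IoU}(P,G)\big)^2$$

(or the L1 difference $|\hat{u} - \text{IoU}(P,G)|$ when use_l1_loss is set).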

def diou_loss(
    inputs, targets, pred_ious, num_objects, loss_on_multimask=False, use_l1_loss=False
):
    """
    dIoU loss (error of the predicted IoU score)
    Args:
        inputs: A float tensor of arbitrary shape.
                The predictions for each example.
        targets: A float tensor with the same shape as inputs. Stores the binary
                 classification label for each element in inputs
                (0 for the negative class and 1 for the positive class).
        pred_ious: A float tensor containing the predicted IoU scores per mask
        num_objects: Number of objects in the batch
        loss_on_multimask: True if multimask prediction is enabled
        use_l1_loss: Whether to use L1 loss instead of MSE loss
    Returns:
        dIoU loss tensor
    """
    assert inputs.dim() == 4 and targets.dim() == 4
    pred_mask = inputs.flatten(2) > 0
    gt_mask = targets.flatten(2) > 0
    area_i = torch.sum(pred_mask & gt_mask, dim=-1).float()
    area_u = torch.sum(pred_mask | gt_mask, dim=-1).float()
    actual_ious = area_i / torch.clamp(area_u, min=1.0)

    if use_l1_loss:
        loss = F.l1_loss(pred_ious, actual_ious, reduction="none")
    else:
        loss = F.mse_loss(pred_ious, actual_ious, reduction="none")
    if loss_on_multimask:
        return loss / num_objects
    return loss.sum() / num_objects

IoU Loss: this one is the actual IoU loss and directly reflects the overlap between the predicted and ground-truth masks. Here I map it through tanh (range roughly 0-1); multiplying by 3 maps the IoU into the region where the gradient is relatively large, to better separate predictions of different quality.
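Concretely, the loss implemented below is

$$\mathcal{L}_{\text{IoU}} = 1 - \tanh\!\big(3 \cdot \text{IoU}(P, G)\big).$$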

def iou_loss(
    inputs, targets, num_objects, loss_on_multimask=False
):
    """
    IoU loss
    Args:
        inputs: A float tensor of arbitrary shape.
                The predictions for each example.
        targets: A float tensor with the same shape as inputs. Stores the binary
                 classification label for each element in inputs
                (0 for the negative class and 1 for the positive class).
        num_objects: Number of objects in the batch
        loss_on_multimask: True if multimask prediction is enabled
    Returns:
        IoU loss tensor
    """
    assert inputs.dim() == 4 and targets.dim() == 4
    pred_mask = inputs.flatten(2) > 0
    gt_mask = targets.flatten(2) > 0
    area_i = torch.sum(pred_mask & gt_mask, dim=-1).float()
    area_u = torch.sum(pred_mask | gt_mask, dim=-1).float()
    actual_ious = area_i / torch.clamp(area_u, min=1.0)
    loss = 1 - torch.tanh(actual_ious * 3)
    if loss_on_multimask:
        return loss / num_objects
    return loss.sum() / num_objects

The total loss is a weighted sum of the losses above.
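In symbols (the training script later uses weights 0.4, 12.0, 5.0 and 2.0):

$$\mathcal{L} = w_{\text{dice}}\,\mathcal{L}_{\text{Dice}} + w_{\text{focal}}\,\mathcal{L}_{\text{focal}} + w_{\text{dIoU}}\,\mathcal{L}_{\text{dIoU}} + w_{\text{IoU}}\,\mathcal{L}_{\text{IoU}}.$$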

def total_loss(
    inputs,
    targets,
    pred_ious,
    num_objects,
    alpha: float = 0.25,
    gamma: float = 2,
    loss_on_multimask: bool = False,
    use_l1_loss: bool = False,
    dice_weight: float = 1.0,
    focal_weight: float = 1.0,
    diou_weight: float = 1.0,
    iou_weight: float = 1.0,
) -> Tuple[torch.Tensor, Dict]:

    # Compute individual losses
    dice = dice_loss(inputs, targets, num_objects, loss_on_multimask)
    focal = sigmoid_focal_loss(inputs, targets, num_objects, alpha, gamma, loss_on_multimask)
    diou = diou_loss(inputs, targets, pred_ious, num_objects, loss_on_multimask, use_l1_loss)
    iou = iou_loss(inputs, targets, num_objects, loss_on_multimask)

    # Combine losses with weights
    total = dice_weight * dice + focal_weight * focal + diou_weight * diou + iou * iou_weight

    return total, {"dice": dice_weight * dice, "focal": focal_weight * focal, "diou": diou_weight * diou, "iou": iou_weight * iou}

🚀Training script

import os
import numpy as np
import cv2
import torch
from torch.utils.data import DataLoader
import torchvision.transforms as transforms
from torch.amp import GradScaler, autocast
from torch.utils.tensorboard import SummaryWriter

from sam2.build_sam import build_sam2
from sam2.sam2_image_predictor import SAM2ImagePredictor
from training.loss_fn import total_loss
from training.dataset import LabPicsDataset
import sys

os.chdir(sys.path[0])

Forward pass

Each batch from the dataset is passed step by step through the model's components; inference can be run with or without prompts.

def forward_net(predictor, image_list, use_prompt=False, input_point=None, input_label=None):
    predictor.set_image_batch(image_list)  # apply the image encoder
    if use_prompt:
        mask_input, unnorm_coords, labels, unnorm_box = predictor._prep_prompts(input_point, input_label, box=None, mask_logits=None, normalize_coords=True)      # with prompts
        sparse_embeddings, dense_embeddings = predictor.model.sam_prompt_encoder(points=(unnorm_coords, labels), boxes=None, masks=None)
    else:
        sparse_embeddings, dense_embeddings = predictor.model.sam_prompt_encoder(points=None, boxes=None, masks=None)       # no prompt
    
    high_res_features = [feat_level[-1].unsqueeze(0) for feat_level in predictor._features["high_res_feats"]]
    low_res_masks, prd_scores, _, _ = predictor.model.sam_mask_decoder(
        image_embeddings=predictor._features["image_embed"],
        image_pe=predictor.model.sam_prompt_encoder.get_dense_pe(),
        sparse_prompt_embeddings=sparse_embeddings,
        dense_prompt_embeddings=dense_embeddings,
        multimask_output=False,
        repeat_image=False,
        high_res_features=high_res_features
    )
    prd_masks = predictor._transforms.postprocess_masks(low_res_masks, predictor._orig_hw[-1])  # upsample the masks to the original image resolution
    return prd_masks, prd_scores, torch.sigmoid(prd_masks[:, 0])

Training function

Function that trains one epoch: forward pass + loss computation + backward pass + gradient update + logging + learning-rate update.

def train_epoch(predictor, train_dataloader, epoch, optimizer, scaler, scheduler, writer, use_prompt=False, accumulation_steps=4):
    step = len(train_dataloader) * epoch
    for batch_idx, (image, mask, input_point, input_label) in enumerate(train_dataloader):
        gt_mask = mask.float().to(device).unsqueeze(1)
        step += 1
        with autocast('cuda', torch.bfloat16):  # mixed-precision training
            prd_masks, prd_scores, prd_mask = forward_net(predictor, image, use_prompt, input_point if use_prompt else None, input_label if use_prompt else None)
            loss, losses = total_loss(inputs=prd_masks, targets=gt_mask, pred_ious=prd_scores, num_objects=image.shape[0],
                                    dice_weight=0.4, focal_weight=12.0, diou_weight=5.0, iou_weight=2.0)

        loss = loss / accumulation_steps

        # backward pass with gradient accumulation: gradients build up over
        # accumulation_steps batches, then the optimizer steps and clears them
        scaler.scale(loss).backward()
        if (batch_idx + 1) % accumulation_steps == 0:
            scaler.step(optimizer)
            scaler.update()
            optimizer.zero_grad()

        # logging
        if (batch_idx + 1) % 50 == 0:
            print(f"Epoch {epoch}, Batch {batch_idx + 1}, Loss: {loss.item() * accumulation_steps:.3f},", 
                  f"Dice Loss: {losses['dice'].item():.2f}, Focal Loss: {losses['focal'].item():.2f}, dIOU Loss: {losses['diou'].item():.2f}, IOU Loss: {losses['iou'].item():.2f}")
        writer.add_scalar('train/total loss', loss.item() * accumulation_steps, step)
        writer.add_scalar('lr', optimizer.param_groups[0]['lr'], step)
        for key, value in losses.items():
            writer.add_scalar(f'train/{key}loss', value.item(), step)
            
    scheduler.step()            # update the learning rate at the end of each epoch

Validation function

Validates over the whole test set; gradients must be disabled. The flow is forward pass + loss computation + logging.

def validate(predictor, val_dataloader, epoch, writer, use_prompt=False):
    predictor.model.sam_mask_decoder.train(False)
    total_val_loss = 0.0
    with torch.no_grad():
        for batch_idx, (image, mask, input_point, input_label) in enumerate(val_dataloader):    
            gt_mask = mask.float().to(device).unsqueeze(1)
            prd_masks, prd_scores, prd_mask = forward_net(predictor, image, use_prompt, input_point if use_prompt else None, input_label if use_prompt else None)
            val_loss, val_losses = total_loss(inputs=prd_masks, targets=gt_mask, pred_ious=prd_scores, num_objects=image.shape[0],
                                    dice_weight=0.4, focal_weight=12.0, diou_weight=5.0, iou_weight=2.0)
            total_val_loss += val_loss.item()
    total_val_loss /= len(val_dataloader)
    predictor.model.sam_mask_decoder.train(True)
    print(f"Epoch {epoch + 1}, Val Loss: {total_val_loss:.3f}")
    writer.add_scalar('val/total_loss', total_val_loss, epoch + 1)

Training flow

device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
batch_size = 32
model_scale_dict = {"b+": "base_plus", "l": "large", "s": "small"}
model_scale = "l"

# dataset and dataloaders
train_dataset = LabPicsDataset("../assets/LabPicsV1/Simple/Train/", transform=transforms.ToTensor())
train_dataloader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True, num_workers=4)
val_dataset = LabPicsDataset("../assets/LabPicsV1/Simple/Test", transform=transforms.ToTensor())
val_dataloader = DataLoader(val_dataset, batch_size=16, shuffle=False, num_workers=4)

# load the model
sam2_checkpoint = f"../checkpoints/sam2_hiera_{model_scale_dict[model_scale]}.pt"
model_cfg = f"../sam2_configs/sam2_hiera_{model_scale}.yaml"
sam2_model = build_sam2(model_cfg, sam2_checkpoint, device="cuda")
predictor = SAM2ImagePredictor(sam2_model)
# predictor.model.load_state_dict(torch.load("./output/exp2/model_epoch_40.pt"))

# training settings
predictor.model.sam_mask_decoder.train(True)
predictor.model.sam_prompt_encoder.train(False)
predictor.model.image_encoder.train(False)
optimizer = torch.optim.AdamW(params=predictor.model.parameters(), lr=1e-5, weight_decay=4e-5)
scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=50, eta_min=1e-8)
scaler = GradScaler('cuda')
writer = SummaryWriter(log_dir="output")
use_prompt = False
epochs = 500

print("train on")   
# 训练循环
for epoch in range(epoches):
    # 训练环节
    train_epoch(predictor, train_dataloader, epoch, optimizer, scaler, scheduler, writer, use_prompt)
    # 验证环节
    if (epoch + 1) % 10 == 0:
        validate(predictor, val_dataloader, epoch, writer, use_prompt)

    if (epoch + 1) % 20 == 0:
        # 每20个 epoch 保存一次模型
        torch.save(predictor.model.state_dict(), f"./output/model_epoch_{epoch + 1}.pt")

writer.close()

The training flow is much like that of other deep-learning models, with a few training tricks.

Techniques

Cosine learning rate

  • Freezing: the image encoder and prompt encoder are frozen, since they have already learned good image and prompt feature extraction; freezing them also speeds up computation.
  • Cosine learning-rate scheduler: the learning rate falls and then rises along a cosine curve over the epochs; its extremes are the lr passed to the optimizer and the eta_min passed to the scheduler, and T_max is half a cosine period (from the maximum down to the minimum). See the short sketch after this list.
  • Mixed-precision training: with autocast('cuda', torch.float16): runs the forward pass in FP16 to speed it up, while gradients are handled in FP32 during the backward pass and update to preserve accuracy.
  • Visualization: TensorBoard logs the training variables (losses, learning rate, etc., as in the figure). During training, run tensorboard --logdir=<log directory> in a terminal and open the reported address in a browser to see the dashboard.
  • Gradient accumulation: to simulate a large batch with a big model or limited GPU memory, gradients from several backward passes are accumulated for accumulation_steps steps before each update.
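A tiny standalone sketch of the scheduler's behaviour (a dummy parameter, not the SAM2 model), showing how the learning rate moves between lr and eta_min over T_max epochs:

import torch

param = torch.nn.Parameter(torch.zeros(1))                 # dummy parameter just to build an optimizer
optimizer = torch.optim.AdamW([param], lr=1e-5)
scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=50, eta_min=1e-8)

for epoch in range(100):
    optimizer.step()                                        # optimizer step first, then scheduler step
    scheduler.step()
    if (epoch + 1) % 10 == 0:
        # epochs 1-50: lr falls from 1e-5 towards 1e-8; epochs 51-100: it rises back
        print(epoch + 1, optimizer.param_groups[0]["lr"])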

🔥Evaluation

🚀Metrics

IoU, as introduced with the loss functions above, only without the tanh mapping.

import numpy as np
from scipy.ndimage import binary_dilation, binary_erosion   # used by boundary_f1_score below

def iou(inputs, targets, num_objects):
    """
    Compute intersection over union (IoU)
    :param inputs: predicted segmentation, shape [H, W]
    :param targets: ground-truth labels, shape [H, W]
    :param num_objects: number of masks (classes)
    :return: IoU value
    """
    iou_list = []
    for i in range(1, num_objects + 1):  # background is 0, so compute IoU for each object starting from 1
        pred_mask = (inputs == i)
        true_mask = (targets == i)
        intersection = np.logical_and(pred_mask, true_mask).sum()
        union = np.logical_or(pred_mask, true_mask).sum()
        if union == 0:
            iou_list.append(1.0)  # define IoU as 1 when the union is 0 (edge case)
        else:
            iou_list.append(intersection / union)
    return np.mean(iou_list)

DICE, also used earlier; it behaves much like IoU.

def dice(inputs, targets, num_objects):
    """
    Compute the Dice coefficient
    :param inputs: predicted segmentation, shape [H, W]
    :param targets: ground-truth labels, shape [H, W]
    :param num_objects: number of masks (classes)
    :return: Dice coefficient
    """
    dice_list = []
    for i in range(1, num_objects + 1):  # background is 0, so compute Dice for each object starting from 1
        pred_mask = (inputs == i)
        true_mask = (targets == i)
        intersection = np.logical_and(pred_mask, true_mask).sum()
        dice_value = (2. * intersection) / (pred_mask.sum() + true_mask.sum())
        dice_list.append(dice_value)
    return np.mean(dice_list)

Pixel accuracy (PA) checks whether each pixel is predicted correctly; it is a fairly basic metric.

def pixel_accuracy(inputs, targets, num_objects):
    """
    Compute pixel accuracy
    :param inputs: predicted segmentation, shape [H, W]
    :param targets: ground-truth labels, shape [H, W]
    :param num_objects: number of masks (classes) (unused here, kept for a consistent interface)
    :return: pixel accuracy
    """
    correct = (inputs == targets).sum()
    total = inputs.size
    return correct / total

Boundary F1-score measures how well the boundaries of the segmented regions are predicted.

def boundary_f1_score(inputs, targets, num_objects):
    """
    Compute the boundary F1 score
    :param inputs: predicted segmentation, shape [H, W]
    :param targets: ground-truth labels, shape [H, W]
    :param num_objects: number of masks (classes)
    :return: boundary F1 score
    """
    # structuring element for boundary detection (a 3x3 cross)
    struct = np.array([[0, 1, 0],
                       [1, 1, 1],
                       [0, 1, 0]])

    f1_list = []
    for i in range(1, num_objects + 1):  # compute the boundary F1 score for each object
        # predicted and ground-truth masks for this object
        pred_mask = (inputs == i).astype(np.uint8)[0]
        true_mask = (targets == i).astype(np.uint8)[0]

        # extract boundaries
        pred_boundary = binary_dilation(pred_mask, struct) ^ binary_erosion(pred_mask, struct)
        true_boundary = binary_dilation(true_mask, struct) ^ binary_erosion(true_mask, struct)

        # compute TP, FP, FN
        tp = np.logical_and(pred_boundary, true_boundary).sum()
        fp = np.logical_and(pred_boundary, np.logical_not(true_boundary)).sum()
        fn = np.logical_and(np.logical_not(pred_boundary), true_boundary).sum()

        # precision and recall
        if tp + fp == 0:
            precision = 0.0
        else:
            precision = tp / (tp + fp)

        if tp + fn == 0:
            recall = 0.0
        else:
            recall = tp / (tp + fn)

        # F1 score
        if precision + recall == 0:
            f1 = 0.0
        else:
            f1 = 2 * (precision * recall) / (precision + recall)

        f1_list.append(f1)

    return np.mean(f1_list)

🚀Evaluation script

import os
import numpy as np
import torch
from torch.utils.data import DataLoader
import torchvision.transforms as transforms
import sys

from sam2.build_sam import build_sam2
from sam2.sam2_image_predictor import SAM2ImagePredictor
from training.metric import *
from training.dataset import LabPicsDataset

os.chdir(sys.path[0])

device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
batch_size = 16
model_scale_dict = {"b+": "base_plus", "l": "large", "s": "small"}
model_scale = "s"

# dataset and dataloader
dataset = LabPicsDataset("../assets/LabPicsV1/Simple/Test", transform=transforms.ToTensor())
dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=False, num_workers=4)

# load the model
sam2_checkpoint = f"../checkpoints/sam2_hiera_{model_scale_dict[model_scale]}.pt"
model_cfg = f"../sam2_configs/sam2_hiera_{model_scale}.yaml"
sam2_model = build_sam2(model_cfg, sam2_checkpoint, device="cuda")
predictor = SAM2ImagePredictor(sam2_model)
predictor.model.load_state_dict(torch.load("./output/exp4/model_epoch_120.pt"))

use_prompt = False

iou_scores = []
dice_scores = []
pixel_acc_scores = []
boundary_f1_scores = []

for batch_idx, (image, mask, input_point, input_label) in enumerate(dataloader):    
    gt_mask = mask.float().unsqueeze(1)
    gt_mask = [gt_mask[i].numpy() for i in range(gt_mask.shape[0])]
    predictor.set_image_batch(image)
    masks, scores, logits = predictor.predict_batch(
        multimask_output=False,
    )

    num_objects = 1

    for output, target in zip(masks, gt_mask):
        iou_scores.append(iou(output, target, num_objects))
        dice_scores.append(dice(output, target, num_objects))
        pixel_acc_scores.append(pixel_accuracy(output, target, num_objects))
        boundary_f1_scores.append(boundary_f1_score(output, target, num_objects))

print("Average IoU:", np.mean(iou_scores))
print("Average Dice:", np.mean(dice_scores))
print("Average Pixel Accuracy:", np.mean(pixel_acc_scores))
print("Average Boundary F1 Score:", np.mean(boundary_f1_scores))

The flow is simple: run batched prediction, then compute the metric scores.

🔥Experiments

🚀Getting training to run

A 40-epoch trial run shows the model converges; every loss follows a reasonable convergence curve.

Training curves, 40 epochs (large)

Before vs. after fine-tuning

The left image is before training and the right after (both without prompts). Without prompts, the original SAM2 segments poorly, and training improves it considerably. This is only a first pass at the training pipeline; more refinements follow. Inference time for the largest (large) model is on the order of 0.1 s.

Training for 500 epochs and watching the curves, the validation loss bottoms out around epoch 120 while the training loss keeps falling afterwards, indicating overfitting beyond that point. A suitable number of training epochs is therefore within about 150.

| Parameter | Value | Parameter | Value |
| --- | --- | --- | --- |
| epoch | 500 | model_scale | large (224M) |
| froze_image_encoder | True | froze_prompt_encoder | True |
| batch size | 32 | lr_max | 1e-5 |
| weight decay | 4e-5 | lr_min | 1e-8 |
| dice_weight | 0.4 | T_max | 50 |
| focal_weight | 12.0 | diou_weight | 5.0 |
| iou_weight | 2.0 | use_prompt | False |
| training precision | FP16 | accumulation_steps | 1 |

Training curves, 500 epochs (large)

🚀Freezing comparison

| Parameter | Value | Parameter | Value |
| --- | --- | --- | --- |
| epoch | 300 | model_scale | base_plus (80.8M) |
| froze_image_encoder | True/False | froze_prompt_encoder | True |
| batch size | 32 | lr_max | 1e-5 |
| weight decay | 4e-5 | lr_min | 1e-8 |
| dice_weight | 0.4 | T_max | 50 |
| focal_weight | 12.0 | diou_weight | 5.0 |
| iou_weight | 2.0 | use_prompt | False |
| training precision | FP16 | accumulation_steps | 1 |

Freezing comparison

exp3 and exp5 are the b+ model with the image encoder frozen and unfrozen, respectively. Freezing the image encoder gives better results, suggesting that after pre-training on billion-scale data the encoder already extracts image features well; updating it on the downstream dataset is unnecessary and would instead erode that ability.

🚀Prompt comparison

| Parameter | Value | Parameter | Value |
| --- | --- | --- | --- |
| epoch | 300 | model_scale | base_plus (80.8M) |
| froze_image_encoder | True | froze_prompt_encoder | True |
| batch size | 32 | lr_max | 1e-5 |
| weight decay | 4e-5 | lr_min | 1e-8 |
| dice_weight | 0.4 | T_max | 50 |
| focal_weight | 12.0 | diou_weight | 5.0 |
| iou_weight | 2.0 | use_prompt | False/True |
| training precision | FP16 | accumulation_steps | 1 |

Prompt comparison

exp3 trains without prompts and exp6 with prompts, i.e. a point randomly sampled inside the mask annotation is fed to the model, and validation also uses prompts. The loss drops markedly, showing that using prompts clearly improves the model's segmentation ability on this task and dataset.

Prompt comparison 2

exp7 instead disables prompts at validation time: a SAM2 trained with prompts performs poorly when used without them, so whether to train with prompts depends on how the model will be deployed.

🚀Training-precision comparison

| Parameter | Value | Parameter | Value |
| --- | --- | --- | --- |
| epoch | 500 | model_scale | base_plus (80.8M) |
| froze_image_encoder | True | froze_prompt_encoder | True |
| batch size | 32/32/16 | lr_max | 1e-5 |
| weight decay | 4e-5 | lr_min | 1e-8 |
| dice_weight | 0.4 | T_max | 50 |
| focal_weight | 12.0 | diou_weight | 5.0 |
| iou_weight | 2.0 | use_prompt | False |
| training precision | FP16/BF16/FP32 | accumulation_steps | 1 |

Running the forward pass in half precision reduces memory pressure and speeds up computation, without noticeably hurting the model.

Training-precision comparison

exp8, exp9 and exp10 train in float16, bfloat16 and float32, respectively. Because float32 increases memory pressure, only a smaller batch size fits; as the curves show, float32 does not even give better results.

🚀Gradient-accumulation comparison

| Parameter | Value | Parameter | Value |
| --- | --- | --- | --- |
| epoch | 500 | model_scale | base_plus (80.8M) |
| froze_image_encoder | True | froze_prompt_encoder | True |
| batch size | 32 | lr_max | 1e-5 |
| weight decay | 4e-5 | lr_min | 1e-8 |
| dice_weight | 0.4 | T_max | 50 |
| focal_weight | 12.0 | diou_weight | 5.0 |
| iou_weight | 2.0 | use_prompt | False |
| training precision | BF16 | accumulation_steps | 1/4 |

exp8 uses no gradient accumulation, while exp11 accumulates for 4 steps to simulate a virtual batch size of 32*4. The results are about the same, but overfitting is noticeably reduced; a larger learning rate might do even better.

Gradient-accumulation comparison

🚀Evaluation

Three model scales are trained as below, and for each the checkpoint with the lowest validation loss is evaluated (not necessarily the best possible model, but the best available under these training conditions).

| Parameter | Value | Parameter | Value |
| --- | --- | --- | --- |
| epoch | 500 | model_scale | s/b+/l |
| froze_image_encoder | True | froze_prompt_encoder | True |
| batch size | 32 | lr_max | 1e-5 |
| weight decay | 4e-5 | lr_min | 1e-8 |
| dice_weight | 0.4 | T_max | 50 |
| focal_weight | 12.0 | diou_weight | 5.0 |
| iou_weight | 2.0 | use_prompt | False |
| training precision | FP16 | accumulation_steps | 1 |

| Scale | IoU | Dice | Pixel Accuracy | Boundary F1 Score |
| --- | --- | --- | --- | --- |
| s | 0.867 | 0.918 | 0.984 | 0.276 |
| b+ | 0.875 | 0.922 | 0.984 | 0.283 |
| l | 0.876 | 0.925 | 0.985 | 0.288 |

The larger the model, the better the test results; the PA metric does not discriminate much.
