LISA图像分割模型的上手和微调

LISA系列

本文记录LISA系列的相关内容,任务为图像分割。LISA、LISA++仓库如下(两个分支):

LISA

LISA的这个仓库给了很完整的上手介绍,以及不算太复杂的训练、预测脚本,基本可以在其基础上做适配,而不用从头搞。

LISA架构

LISA用的是多模态模型进行分割,LISA架构如上,图文都输入到多模态大模型中,以LoRA的方式训练,输出嵌入表达;原图经过一个视觉网络后和LLM输出一起输入到解码器,得到最后的mask。

多模态LLM用的LLaVA,视觉网络是用的SAM的编码器(本质是ViT)。

LISA++架构

LISA++则是增加了实例分割能力,LISA只能一次输出单个掩码,这个掩码会包含所有需要的对象,而LISA++会为每个符合要求的目标生成一个掩码。

🔥准备工作

🚀配置

安装库(可以先整一个虚拟环境),大部分库版本要对,要不然很容易报错,特别是transformers、gradio、numpy这些,其他库实在不行装新版试试。

pip install -r requirements.txt
pip install flash-attn --no-build-isolation

🚀下载权重

下载权重,以下是比较新的,可以选7B或者13B(本文用的7B,因为4090D显存刚好不够跑bf16的13B),放到自定义的位置。其中LISA应该是直接包含LLaVA和SAM的,也就是网络结构当中的全部。如果只是预测和评估就只需要下载LISA和LLaVA视觉主干(理论上第二个不用下,但是代码里面似乎是单独读取权重的);如果要微调就需要LISA(或LLaVA)、LLaVA视觉主干、SAM-VIT-H。

下载完成后,打开LISA权重文件夹中的config.json文件,里面的vision_tower的路径应当改成视觉主干文件夹的路径。

🚀数据集

数据集用的还是LabPicsV1,是一个相对不大数据集,用于跑通代码和调优方法的尝试。主要是对容器及其内部的材料(固、液体)进行分割,它具有一定难度且体量不大,很适合初步实验。

🔥LISA预测

🚀预测脚本

LISA提供的脚本可以直接完成对话预测,只需要运行其中的chat.py脚本。

参数如下,最主要是LLM的路径version、精度precision、加载4/8bit模型、视觉主干的路径vision-tower。

def parse_args(args):
    parser = argparse.ArgumentParser(description="LISA chat")
    parser.add_argument("--version", default="./weight/lisa")
    parser.add_argument("--vis_save_path", default="./vis_output/my", type=str)
    parser.add_argument(
        "--precision",
        default="bf16",
        type=str,
        choices=["fp32", "bf16", "fp16"],
        help="precision for inference",
    )
    parser.add_argument("--image_size", default=1024, type=int, help="image size")
    parser.add_argument("--model_max_length", default=512, type=int)
    parser.add_argument("--lora_r", default=8, type=int)
    parser.add_argument(
        "--vision-tower", default="./weight/clip-vit-large-patch14", type=str
    )
    parser.add_argument("--local-rank", default=0, type=int, help="node rank")
    parser.add_argument("--load_in_8bit", action="store_true", default=False)
    parser.add_argument("--load_in_4bit", action="store_true", default=False)
    parser.add_argument("--use_mm_start_end", action="store_true", default=True)
    parser.add_argument(
        "--conv_type",
        default="llava_v1",
        type=str,
        choices=["llava_v1", "llava_llama_2"],
    )
    return parser.parse_args(args)

主要流程和大模型的预测差不多:

  1. 准备:加载分词器、配置模型参数、加载模型(LLM和视觉主干);

  2. 对话:文本构建、读取图片并嵌入、文本嵌入、图文输入到模型、绘制图片并保存。

对话和预测

运行脚本后,输入文本和图片路径即可,英文效果好。

🚀预测效果

7B模型在4090D上运行速度最快应该能到0.5s左右,慢的会2s以上(一般就是0.5-1s),算是相当快了(SAM2可能都得0.12s)。

对话效果

对话效果2

who is controlling the uav?

可以看到模型对文本的理解能力比较好地应用到了图像领域,这种抽象的理解能力是传统模型不可能做到的。

对话效果4

find the usb-a

对于比较不常见的需求,模型没有办法认识,这跟数据集的广度有关系,如果有需求可以进行LoRA微调,有提供合并权重的脚本。

🚀gradio可视化UI

app.py脚本使用gradio这个webUI库实现了可视化预测,gr.Interface是按照gradio官方预设的顺序渲染这些组件:

demo = gr.Interface(
    inference,
    inputs=[
        gr.Textbox(lines=1, placeholder=None, label="Text Instruction"),
        gr.Image(type="filepath", label="Input Image"),
    ],
    outputs=[
        gr.Image(type="pil", label="Segmentation Output"),
        gr.Textbox(lines=1, placeholder=None, label="Text Output"),
    ],
    title=title,
    description=description,
    article=article,
    examples=examples,
    allow_flagging="auto",
)

gradio界面

可以看到有预设好的标题、描述,中间是交互界面,底下有一些样例,点击可以直接装填到交互区域,gradio很适合作为部署的一种方案。

🔥LISA评估

🚀dataset和metric

下面是数据集的类。

import numpy as np
import cv2
import torch
import random
from torch.utils.data import Dataset, DataLoader
import os
from transformers import CLIPImageProcessor

from utils.utils import (DEFAULT_IM_END_TOKEN, DEFAULT_IM_START_TOKEN,
                         DEFAULT_IMAGE_TOKEN, IMAGE_TOKEN_INDEX)
from model.llava import conversation as conversation_lib
from utils.utils import ANSWER_LIST, SHORT_QUESTION_LIST

# 定义 Dataset 类
class LabPicsDataset(Dataset):
    def __init__(self, data_dir, transform=None):
        self.data_dir = data_dir
        self.transform = transform
        self.data = self._load_data()

    def _load_data(self):
        data = []
        image_folder = os.path.join(self.data_dir, "Image")
        annotation_folder = os.path.join(self.data_dir, "Instance")
        for name in os.listdir(image_folder):
            image_path = os.path.join(image_folder, name)
            annotation_path = os.path.join(annotation_folder, name[:-4] + ".png")
            data.append({"image": image_path, "annotation": annotation_path})
        return data

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        entry = self.data[idx]
        image = cv2.imread(entry["image"])[..., ::-1]  # 读取图像并转换为 RGB 格式
        annotation = cv2.imread(entry["annotation"])  # 读取标注

        # 调整图像大小
        r = np.min([1024 / image.shape[1], 1024 / image.shape[0]])  # 缩放因子
        image = cv2.resize(image, (int(image.shape[1] * r), int(image.shape[0] * r)))
        annotation = cv2.resize(annotation, (int(annotation.shape[1] * r), int(annotation.shape[0] * r)), interpolation=cv2.INTER_NEAREST)

        if image.shape[0] < 1024:
            image = np.concatenate([image, np.zeros([1024 - image.shape[0], image.shape[1], 3], dtype=np.uint8)], axis=0)
            annotation = np.concatenate([annotation, np.zeros([1024 - annotation.shape[0], annotation.shape[1], 3], dtype=np.uint8)], axis=0)
        if image.shape[1] < 1024:
            image = np.concatenate([image, np.zeros([image.shape[0], 1024 - image.shape[1], 3], dtype=np.uint8)], axis=1)
            annotation = np.concatenate([annotation, np.zeros([annotation.shape[0], 1024 - annotation.shape[1], 3], dtype=np.uint8)], axis=1)

        # 合并材料和容器标注
        mat_map = annotation[:, :, 0]
        ves_map = annotation[:, :, 2]
        mat_map[mat_map == 0] = ves_map[mat_map == 0] * (mat_map.max() + 1)

        # 获取二值掩码和点
        inds = np.unique(mat_map)[1:]
        if len(inds) > 0:
            ind = inds[np.random.randint(len(inds))]
            # mask = (mat_map == ind).astype(np.uint8)
            mask = (mat_map > 0).astype(np.uint8)           # 全部mask合成一个(全部预测出来)
            coords = np.argwhere(mask > 0)
            yx = coords[np.random.randint(len(coords))]
            point = [[yx[1], yx[0]]]
        else:
            # 如果没有有效标注,返回全零掩码和随机点
            mask = np.zeros((image.shape[:2]), dtype=np.uint8)
            point = [[np.random.randint(0, 1024), np.random.randint(0, 1024)]]

        if self.transform:
            image = self.transform(image)

        return image, mask, np.array(point, dtype=np.float32), np.ones([1])

评估指标和SAM2文档里面的一样,有IoU、GIoU、CIoU、Dice、PA、Boundary F1。

import numpy as np
from scipy.ndimage import binary_dilation, binary_erosion
from scipy.spatial import ConvexHull

def iou(inputs, targets, num_objects):
    """
    计算交并比(IoU)
    :param inputs: 预测的分割结果,形状为 [H, W]
    :param targets: 真实的分割标签,形状为 [H, W]
    :param num_objects: mask 个数(类别数)
    :return: IoU 值
    """
    iou_list = []
    for i in range(1, num_objects + 1):  # 假设背景是 0,从 1 开始计算各个对象的 IoU
        pred_mask = (inputs == i)
        true_mask = (targets == i)
        intersection = np.logical_and(pred_mask, true_mask).sum()
        union = np.logical_or(pred_mask, true_mask).sum()
        if union == 0:
            iou_list.append(1.0)  # 如果 union 为 0,IoU 定义为 1(特殊情况处理)
        else:
            iou_list.append(intersection / union)
    return np.mean(iou_list)

def giou(inputs, targets, num_objects):
    """
    计算 GIoU
    :param inputs: 预测的分割结果,形状为 [H, W]
    :param targets: 真实的分割标签,形状为 [H, W]
    :param num_objects: mask 个数(类别数)
    :return: GIoU 值
    """
    giou_list = []
    for i in range(1, num_objects + 1):  # 假设背景是 0,从 1 开始计算各个对象的 GIoU
        pred_mask = (inputs == i)
        true_mask = (targets == i)
        
        # 计算交集和并集
        intersection = np.logical_and(pred_mask, true_mask).sum()
        union = np.logical_or(pred_mask, true_mask).sum()
        
        # 如果 union 为 0,直接返回 1.0(特殊情况处理)
        if union == 0:
            giou_list.append(1.0)
            continue

        pred_coords = np.argwhere(pred_mask)
        true_coords = np.argwhere(true_mask)
        all_coords = np.concatenate([pred_coords, true_coords], axis=0)
        if len(all_coords) == 0:
            giou_list.append(1.0)
            continue

        # 使用凸包代替目标检测中的最小外接矩形
        hull = ConvexHull(all_coords[:, 1:])
        hull_area = hull.volume  # 2D convex hull volume is the area
        
        # 计算 GIoU
        giou = intersection / union - (hull_area - union) / hull_area
        giou_list.append(giou)
    
    return np.mean(giou_list)


def ciou(inputs, targets, num_objects):
    """
    计算 CIoU
    :param inputs: 预测的分割结果,形状为 [H, W]
    :param targets: 真实的分割标签,形状为 [H, W]
    :param num_objects: mask 个数(类别数)
    :return: CIoU 值
    """
    ciou_list = []
    for i in range(1, num_objects + 1):
        pred_mask = (inputs == i)
        true_mask = (targets == i)
        
        # I和U
        intersection = np.logical_and(pred_mask, true_mask).sum()
        union = np.logical_or(pred_mask, true_mask).sum()
        
        if union == 0:
            ciou_list.append(1.0)
            continue
        
        pred_coords = np.argwhere(pred_mask)
        true_coords = np.argwhere(true_mask)
        
        if len(pred_coords) == 0 or len(true_coords) == 0:
            ciou_list.append(0.0)
            continue
        
        # 外接矩形
        x_coords = np.concatenate([pred_coords[:, 1], true_coords[:, 1]])
        y_coords = np.concatenate([pred_coords[:, 2], true_coords[:, 2]])
        min_x, max_x = np.min(x_coords), np.max(x_coords)
        min_y, max_y = np.min(y_coords), np.max(y_coords)
        # 边界框
        pred_bbox = np.array([[np.min(pred_coords[:, 1]), np.min(pred_coords[:, 2])],
                              [np.max(pred_coords[:, 1]), np.max(pred_coords[:, 2])]])
        true_bbox = np.array([[np.min(true_coords[:, 1]), np.min(true_coords[:, 2])],
                              [np.max(true_coords[:, 1]), np.max(true_coords[:, 2])]])
        # 边界框宽/高
        w_pred = pred_bbox[1, 0] - pred_bbox[0, 0] + 1
        h_pred = pred_bbox[1, 1] - pred_bbox[0, 1] + 1
        w_true = true_bbox[1, 0] - true_bbox[0, 0] + 1
        h_true = true_bbox[1, 1] - true_bbox[0, 1] + 1
        # 两框的中心距离
        center_distance = np.sqrt(((pred_bbox[0, 0] + pred_bbox[1, 0]) / 2 - (true_bbox[0, 0] + true_bbox[1, 0]) / 2) ** 2 +
                                  ((pred_bbox[0, 1] + pred_bbox[1, 1]) / 2 - (true_bbox[0, 1] + true_bbox[1, 1]) / 2) ** 2)
        # 外接矩形面积
        c = np.sqrt(((max_x - min_x) ** 2) + ((max_y - min_y) ** 2))
        
        v = (4 / (np.pi ** 2)) * (np.arctan(w_true / h_true) - np.arctan(w_pred / h_pred)) ** 2
        
        alpha = v / (1 - intersection / union + v)
        
        ciou = intersection / union - (center_distance ** 2) / c ** 2 - alpha * v
        ciou_list.append(ciou)
    
    return np.mean(ciou_list)

def dice(inputs, targets, num_objects):
    """
    计算 Dice 系数
    :param inputs: 预测的分割结果,形状为 [H, W]
    :param targets: 真实的分割标签,形状为 [H, W]
    :param num_objects: mask 个数(类别数)
    :return: Dice 系数
    """
    dice_list = []
    for i in range(1, num_objects + 1):  # 假设背景是 0,从 1 开始计算各个对象的 Dice
        pred_mask = (inputs == i)
        true_mask = (targets == i)
        intersection = np.logical_and(pred_mask, true_mask).sum()
        dice_value = (2. * intersection) / (pred_mask.sum() + true_mask.sum())
        dice_list.append(dice_value)
    return np.mean(dice_list)

def pixel_accuracy(inputs, targets, num_objects):
    """
    计算像素准确度
    :param inputs: 预测的分割结果,形状为 [H, W]
    :param targets: 真实的分割标签,形状为 [H, W]
    :param num_objects: mask 个数(类别数)(虽然这里用不到,但保持接口一致)
    :return: 像素准确度
    """
    correct = (inputs == targets).sum()
    total = inputs.size
    return correct / total

def boundary_f1_score(inputs, targets, num_objects):
    """
    计算边界 F1 分数
    :param inputs: 预测的分割结果,形状为 [H, W]
    :param targets: 真实的分割标签,形状为 [H, W]
    :param num_objects: mask 个数(类别数)
    :return: 边界 F1 分数
    """
    # 创建一个结构元素用于边界检测(这里使用 3x3 的十字结构)
    struct = np.array([[0, 1, 0],
                       [1, 1, 1],
                       [0, 1, 0]])

    f1_list = []
    for i in range(1, num_objects + 1):  # 对每个对象计算边界 F1 分数
        # 获取预测和真实的目标掩码
        pred_mask = (inputs == i).astype(np.uint8)[0]
        true_mask = (targets == i).astype(np.uint8)[0]

        # 计算边界
        pred_boundary = binary_dilation(pred_mask, struct) ^ binary_erosion(pred_mask, struct)
        true_boundary = binary_dilation(true_mask, struct) ^ binary_erosion(true_mask, struct)

        # 计算 TP、FP、FN
        tp = np.logical_and(pred_boundary, true_boundary).sum()
        fp = np.logical_and(pred_boundary, np.logical_not(true_boundary)).sum()
        fn = np.logical_and(np.logical_not(pred_boundary), true_boundary).sum()

        # 计算精确度和召回率
        if tp + fp == 0:
            precision = 0.0
        else:
            precision = tp / (tp + fp)

        if tp + fn == 0:
            recall = 0.0
        else:
            recall = tp / (tp + fn)

        # 计算 F1 分数
        if precision + recall == 0:
            f1 = 0.0
        else:
            f1 = 2 * (precision * recall) / (precision + recall)

        f1_list.append(f1)

    return np.mean(f1_list)

def estimate(inputs_list, targets_list, num_objects, metric_dict):
    for inputs, targets in zip(inputs_list, targets_list):
        for (metric_name, metric_func), score_list in metric_dict.items():
            score_list.append(metric_func(inputs, targets, num_objects))
    return metric_dict

🚀评估脚本

评估脚本需要根据chat.py进行修改,把它改成对batch数据的预测并增加评估环节。以下展示主要修改的部分,也就是batch循环里面的部分。

    model.eval()
    dataset = LabPicsDataset("/media/zigaa/leofile/leosam2/assets/LabPicsV1/Simple/Test", transform=transforms.ToTensor())
    dataloader = DataLoader(dataset, batch_size=16, shuffle=False, num_workers=4)
    
    metric_dict = {("IoU", iou): [], ("GIoU", giou): [], ("CIoU", ciou): [], ("Dice", dice): [], ("Pixel Accuracy", pixel_accuracy): [], ("Boundary F1 Score", boundary_f1_score): []}

    for batch_idx, (image, mask, input_point, input_label) in enumerate(dataloader):
        gt_masks = mask.float().unsqueeze(1)
        gt_masks = [gt_masks[i].numpy() for i in range(gt_masks.shape[0])]

        conv = conversation_lib.conv_templates[args.conv_type].copy()
        conv.messages = []

        prompt = "show me all the vessels."
        prompt = DEFAULT_IMAGE_TOKEN + "\n" + prompt
        if args.use_mm_start_end:
            replace_token = (
                DEFAULT_IM_START_TOKEN + DEFAULT_IMAGE_TOKEN + DEFAULT_IM_END_TOKEN
            )
            prompt = prompt.replace(DEFAULT_IMAGE_TOKEN, replace_token)

        conv.append_message(conv.roles[0], prompt)
        conv.append_message(conv.roles[1], "")
        prompt = conv.get_prompt()
        
        t1 = time.time()
        original_size_list = [image.shape[-2:] for _ in range(image.shape[0])]

        image_clip = (
            clip_image_processor.preprocess(image, return_tensors="pt")[
                "pixel_values"
            ]
            .cuda()
        )

        if args.precision == "bf16":
            image_clip = image_clip.bfloat16()
        elif args.precision == "fp16":
            image_clip = image_clip.half()
        else:
            image_clip = image_clip.float()

        # image = transform.apply_image(image)
        resize_list = [image.shape[-2:] for _ in range(image.shape[0])]

        image = image.cuda()
        if args.precision == "bf16":
            image = image.bfloat16()
        elif args.precision == "fp16":
            image = image.half()
        else:
            image = image.float()

        input_ids = tokenizer_image_token(prompt, tokenizer, return_tensors="pt")
        input_ids = input_ids.repeat(image.shape[0], 1).cuda()

        output_ids, pred_masks = model.evaluate(
            image_clip,
            image,
            input_ids,
            resize_list,
            original_size_list,
            max_new_tokens=512,
            tokenizer=tokenizer,
        )

        output_ids = [row[row != IMAGE_TOKEN_INDEX] for row in output_ids]
        text_output = [tokenizer.decode(o, skip_special_tokens=False) for o in output_ids]
        text_output = [t.replace("\n", "").replace("  ", " ") for t in text_output]

        pred_masks = [p.detach().cpu().numpy() > 0 for p in pred_masks]

        num_objects = 1
        metric_dict = estimate(pred_masks, gt_masks, num_objects, metric_dict)

    for (metric_name, _), score_list in metric_dict.items():
        print(f"Average {metric_name}:", np.mean(score_list))

给定的文本是”系统:Please finds out all the vessels in the picture.””用户:show me all the vessels.”,符合SAM2之前的训练任务,就是只分割一个mask,分割出所有容器即可。精度用的bf16,fp32会超4090D显存。

🚀评估结果

规模 IoU GIoU CIoU Dice PA F1
LISA 0.784 0.438 0.766 0.857 0.967 0.276
SAM2(l) 无微调, 无提示 0.516 0.071 0.477 0.588 0.921 0.130
SAM2(l) 无微调,有提示 0.734 0.556 0.710 0.798 0.961 0.221
SAM2(l) 微调, 无提示 0.895 0.632 0.884 0.938 0.988 0.291

可以看到在没有微调的情况下半精度的LISA的效果超过了使用提示的SAM2,和针对该数据集微调的SAM2的large模型(无提示)相比还有差距,值得注意的是这里SAM2用的提示是直接提供目标mask中的一个点,是与目标高度相关的极其有用的信息,而LISA的提示词并没有这种作用。

🔥LISA微调

为了看看LISA的下游数据集训练效果如何,是否能通过微调提升对特定小数据集的分割效果,我们可以借用train_ds.py脚本进行LoRA微调。该脚本提供了很多任务的训练,我们就做最简单的一次分割一个mask的任务即可。

🚀dataset

观察utils下的HybridDataset(管理所有监督任务的dataset)和SemSegDataset(官方其中一个监督任务的dataset),可以看到每个样本应当有十个数据。

# SemSegDataset的样本
return (image_path, image, image_clip, conversations,
        masks, label, resize, questions, sampled_classes,
)
# HybridDataset的样本,前者包含了上面这九个变量
return *data[0], inference
变量 描述 变量 描述
image_path 图像路径 label 整张图的标注(张量,每个像素属于哪一类)
image 原图像(张量) resize 变换后的图像尺寸
image_clip 预处理后的图像(各种变换) questions 问题(字符串)
conversations 对话列表 sampled_classes 分类名(字符串)列表
masks label拆分出来的若干个掩码(张量,同一类放在一个mask) inference 是否做推理任务

所以我们只需要把自己的dataset改写成这种输出形式就可以适配其训练脚本,于是我写了个LabPicsDatasetForTrain类专门用于训练,可与评估用的那个分开(如下)。

  • 这个数据集label其实就是mask了,只需要降一个维度;

  • 分类名自己设计,我们分割的是容器vessels,这个名字也就是用于生成问题的而已,这个问题是从官方预设的一些模板抽取的。

import numpy as np
import cv2
import torch
import random
from torch.utils.data import Dataset, DataLoader
import os
from transformers import CLIPImageProcessor

from utils.utils import (DEFAULT_IM_END_TOKEN, DEFAULT_IM_START_TOKEN,
                         DEFAULT_IMAGE_TOKEN, IMAGE_TOKEN_INDEX)
from model.llava import conversation as conversation_lib
from utils.utils import ANSWER_LIST, SHORT_QUESTION_LIST

class LabPicsDatasetForTrain(Dataset):
    def __init__(self, data_dir, tokenizer, vision_tower, transform=None):
        self.data_dir = data_dir
        self.transform = transform
        self.data = self._load_data()
        self.tokenizer = tokenizer
        self.clip_image_processor = CLIPImageProcessor.from_pretrained(vision_tower)
        self.short_question_list = SHORT_QUESTION_LIST
        self.answer_list = ANSWER_LIST

    def _load_data(self):
        data = []
        image_folder = os.path.join(self.data_dir, "Image")
        annotation_folder = os.path.join(self.data_dir, "Instance")
        for name in os.listdir(image_folder):
            image_path = os.path.join(image_folder, name)
            annotation_path = os.path.join(annotation_folder, name[:-4] + ".png")
            data.append({"image": image_path, "annotation": annotation_path})
        return data

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        entry = self.data[idx]
        image = cv2.imread(entry["image"])[..., ::-1]  # 读取图像并转换为 RGB 格式
        annotation = cv2.imread(entry["annotation"])  # 读取标注

        image_clip = self.clip_image_processor.preprocess(
                image, return_tensors="pt"
            )["pixel_values"][0]
        # 调整图像大小
        r = np.min([1024 / image.shape[1], 1024 / image.shape[0]])  # 缩放因子
        image = cv2.resize(image, (int(image.shape[1] * r), int(image.shape[0] * r)))
        annotation = cv2.resize(annotation, (int(annotation.shape[1] * r), int(annotation.shape[0] * r)), interpolation=cv2.INTER_NEAREST)

        if image.shape[0] < 1024:
            image = np.concatenate([image, np.zeros([1024 - image.shape[0], image.shape[1], 3], dtype=np.uint8)], axis=0)
            annotation = np.concatenate([annotation, np.zeros([1024 - annotation.shape[0], annotation.shape[1], 3], dtype=np.uint8)], axis=0)
        if image.shape[1] < 1024:
            image = np.concatenate([image, np.zeros([image.shape[0], 1024 - image.shape[1], 3], dtype=np.uint8)], axis=1)
            annotation = np.concatenate([annotation, np.zeros([annotation.shape[0], 1024 - annotation.shape[1], 3], dtype=np.uint8)], axis=1)

        # 合并材料和容器标注
        mat_map = annotation[:, :, 0]
        ves_map = annotation[:, :, 2]
        mat_map[mat_map == 0] = ves_map[mat_map == 0] * (mat_map.max() + 1)

        # 获取二值掩码和点
        inds = np.unique(mat_map)[1:]
        if len(inds) > 0:
            ind = inds[np.random.randint(len(inds))]
            # mask = (mat_map == ind).astype(np.uint8)
            mask = (mat_map > 0).astype(np.uint8)           # 全部mask合成一个(全部预测出来)
            coords = np.argwhere(mask > 0)
            yx = coords[np.random.randint(len(coords))]
            point = [[yx[1], yx[0]]]
        else:
            # 如果没有有效标注,返回全零掩码和随机点
            mask = np.zeros((image.shape[1:]), dtype=np.uint8)
            point = [[np.random.randint(0, 1024), np.random.randint(0, 1024)]]
        mask = torch.from_numpy(mask).float().unsqueeze(0)

        if self.transform:
            image = self.transform(image)
        resize = image.shape[1:]

        # 文本
        questions = []
        answers = []
        class_ids = [1]
        sampled_classes = ["vessels"]
        for sampled_cls in sampled_classes:
            text = sampled_cls

            assert len(text.split("||")) == 1
            question_template = random.choice(self.short_question_list)
            questions.append(question_template.format(class_name=text.lower()))

            answers.append(random.choice(self.answer_list))

        conversations = []
        conv = conversation_lib.default_conversation.copy()

        i = 0
        while i < len(questions):
            conv.messages = []
            conv.append_message(conv.roles[0], questions[i])
            conv.append_message(conv.roles[1], answers[i])
            conversations.append(conv.get_prompt())
            i += 1

        return entry, image, image_clip, conversations, mask, mask.squeeze(0), resize, questions, sampled_classes, False

🚀训练脚本

有了这个只需要把train_ds.py的HybridDataset改成LabPicsDatasetForTrain(233行附近)就可以训练了,一些需要改的超参数如下表(显卡桌面4090D),其它参数包括损失权重、lora训练的参数就不用太去改了。

    train_dataset = LabPicsDatasetForTrain("/media/zigaa/leofile/leosam2/assets/LabPicsV1/Simple/Train", 
                                           tokenizer=tokenizer,
                                           vision_tower=args.vision_tower,
                                           transform=transforms.ToTensor())
参数名 描述 参数值
version 预训练LLaVA权重路径,可以是LISA完整权重 ./weight/lisa
precision 精度(前向传播等操作) bf16
image_size 图像尺寸 1024
vision-tower LLaVA视觉主干权重路径 ./weight/clip-vit-large-patch14
log_base_dir 记录每次训练用的文件夹,每次训练的tensorboard和权重都在里面 ./runs
exp_name 这次训练的名称,log_base_dir下生成该名称的文件夹 lisa
epochs 迭代数 10
steps_per_epoch 每一代多少步 1500
batch_size 批次大小(乘以steps_per_epoch就是一代采样数),显存不够只能1 1
grad_accumulation_steps 梯度积累 20
lr 学习率(可以大点,有学习率下降调度器) 0.0008
ce_loss_weight 文本交叉熵损失权重(这次主要考虑图像) 0.2
no_eval 不做验证(训练不久没必要做) True
vision_pretrained SAM预训练权重路径 ./weight/sam_vit_h_4b8939.pth

训练脚本很长,主要因为是没有做模块化设计的(一个老长的main加两三个函数就搞定了),是一个LoRA微调VLLM——特别还是能输出图像的VLLM的一个不错的范本。

  • main函数主要流程就是:
  1. 加载超参数,初始化模型(分词器、LISA、LLaVA视觉主干),配置LoRA参数;

  2. 加载数据集;

  3. 配置deepspeed参数并初始化;

  4. 训练循环(训练、验证、保存)。

  • 训练函数和一般深度学习模型训练一样,用的损失是MaskBCELoss(像素级交叉熵)、MaskDICELoss(和IoU差不多)、CELoss(文本的交叉熵),前两个加权合成MaskLoss。

  • 验证函数也不复杂,用的是GIOU和CIOU评估指标,而不是损失。

🚀微调

数据集接入了训练脚本,改完了超参数,就可以训练,画面如下(前面还会加载一堆东西),中间可以用tensorboard看损失、评估等曲线。

训练界面

获得微调权重,训练完成后在终端跑以下指令,就会在./runs/lisa下生成pytorch_model.bin文件夹,里面是完整的LISA权重,里面既有原模型冻结的部分,也有LoRA权重。

cd ./runs/lisa/ckpt_model && python zero_to_fp32.py . ../pytorch_model.bin

成功会输出如下内容,./runs/lisa/pytorch_model.bin里面就有bin类型的分片权重文件。

[2025-05-21 11:12:18,671] [INFO] [real_accelerator.py:239:get_accelerator] Setting ds_accelerator to cuda (auto detect)
Processing zero checkpoint './global_step300'
Loading checkpoint shards: 100%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 1/1 [00:00<00:00, 98.66it/s]
Detected checkpoint of type zero stage 2, world_size: 1
Parsing checkpoint created by deepspeed==0.16.7
Reconstructed Frozen fp32 state dict with 1155 params 7420682060 elements
Reconstructed fp32 state dict with 254 params 288259556 elements
Saving checkpoint shards: 100%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 4/4 [00:13<00:00,  3.46s/it]

合并权重merge_lora_weights_and_save_hf_model.py脚本提供了将生成的bin权重和原模型合并的功能,参数要改的如下。

  --version="./weight/lisa" \
  --weight="./runs/lisa/pytorch_model.bin" \
  --save_path="./weight/lisa_model_finetuned"

但是脚本里面有一些问题,如下,需要把脚本150行附近的torch.load函数(注释部分)改成前面这部分,也就要逐步读取4个分片文件后放到同一个state_dict里面,再进行合并。运行完成没报错就行。在save_path路径下会生成Hugging Face格式的权重(和HF下载的权重一样)

    state_dict = torch.nn.Module().state_dict()
    # 遍历所有分割的权重文件
    for i in range(1, 5):
        file_path = f'{args.weight}/pytorch_model-0000{i}-of-00004.bin'
        state_dict_part = torch.load(file_path, map_location=torch.device('cpu'))
        state_dict.update(state_dict_part)
    # state_dict = torch.load(args.weight, map_location="cpu")

🚀评估

规模 IoU GIoU CIoU Dice PA F1
LISA-7B 0.784 0.438 0.766 0.857 0.967 0.276
LISA-7B 微调 0.871 0.580 0.857 0.922 0.984 0.300
SAM2(l) 微调, 无提示 0.895 0.632 0.884 0.938 0.988 0.291

可以看到微调后有很明显的提升,和SAM2差距也不大,这个也暂时不是最优的训练结果。

🔥LISA++预测

LISA++项目整体结构不变,就多了个chat_instance.py用于预测,没有提供新的app.py,配置方法和LISA一样。

LISAplus对话效果

LISAplus对话效果2

where is the vessel and the liquid? output the segmentation of them.

LISA++对多目标需求的理解能力和实际分割效果都比LISA要强一些。

但是似乎提供的权重似乎并不能像论文中的那样完整描述图片的同时输出[SEG],并且暂时没有找到如何根据一张图进行多轮对话,目前看来可能是生成数据集时用了这个方法,要多轮对话估计得在chat_instance.py改。

  • Copyrights © 2023-2025 LegendLeo Chen
  • 访问人数: | 浏览次数:

请我喝杯咖啡吧~

支付宝
微信