Reasoning RF-DETR:基于多模态推理的开放世界图像分割系统(附完整代码)

一个结合 MLLM(多模态大语言模型)、Grounding DINO 和 SAM 自动构建数据集,并训练 RF-DETR 端到端分割模型的完整闭环项目。


目录


一、项目简介

1.1 这是什么项目

本项目实现了一条**“零代码全自动打标 → 端到端带空间感知的分割模型训练 → 工业级推理”**的完整闭环流水线。

用户只需输入一句自然语言指令(如"分割出所有可以移动的软装家具"),系统就能完成以下工作:调用大模型理解指令语义,自动在图像中定位目标,生成像素级分割掩码,并训练一个轻量级端到端模型来完成同样的任务。

注意:当前模型为简易测试版本,用于快速验证流水线可行性。 模型采用基础的 ResNet50 + BERT + Transformer 架构,文本注入方式为简单的 mean-pooling 加法,Mask Head 使用单尺度特征图点积。该版本在少量数据上可跑通全流程,但分割精度和泛化能力有限。如需用于实际场景,建议参考第十一节中的改进方向进行优化。

1.2 解决什么问题

传统图像分割模型需要大量人工标注的像素级Mask,标注成本极高。本项目通过 MLLM + Grounding DINO + SAM 三级流水线实现零人工标注,大幅降低数据构建成本。

1.3 应用场景

  • 候诊室和办公室软装识别与一键消除
  • 工地危险物检测与分割
  • 桥梁和机械结构缺陷定位
  • 室内设计辅助(光照区域分析、家具布局规划)

二、核心创新点

创新点一:三级自动打标流水线。 MLLM负责语义理解,Grounding DINO负责开放世界定位,SAM负责像素级分割,三者串联实现零人工标注。

创新点二:推理式文本引导分割。 不是简单的"分割猫",而是接受复杂推理指令,例如"找出下午会被阳光直射的地板区域"。

创新点三:轻量级RF-DETR架构。 使用ResNet50作为视觉骨干、BERT作为文本编码器、Transformer进行跨模态融合,冻结骨干网络后单卡即可训练。

创新点四:2D空间位置编码。 注入行列位置信息,让Transformer具备空间方向感。

创新点五:匈牙利二分图匹配。 解决50个Query无序输出与真实标签之间的梯度分配问题。


三、系统架构总览

整个系统分为两个解耦的阶段。

阶段一:数据自动构建(离线预处理)

原始图片和中文推理指令首先送入MLLM API(GPT-4o),大模型理解指令语义后输出英文实体词,例如输入"分割出所有可以移动的软装家具",输出"chairs, coffee table, plant"。这些英文实体词送入Grounding DINO,它在图像上预测出多个Bounding Box,将沙发、茶几、绿植等目标用矩形框锁定。这些Box坐标送入SAM,SAM沿着物体的实际边缘切出像素级的高清Mask。最后,脚本自动将原图、中文原指令、合并后的高清Mask、所有Bbox坐标打包写入JSON文件,形成可训练的数据集。

阶段二:端到端模型训练与推理

AdvancedRFDETR模型接收图像和文本指令作为输入。ResNet50提取视觉特征并注入2D空间位置编码,BERT提取文本语义特征并与Transformer的Object Queries融合,Transformer编解码器计算跨模态注意力,Mask Head将解码输出与图像特征图进行点积生成最终Mask。训练时使用匈牙利匹配器找到与真实标签最匹配的Query,对该Query计算分类损失和分割损失,对其余Query计算背景惩罚损失。推理时取置信度最高的Query对应的Mask作为最终输出。


四、项目目录结构

Reasoning_RFDETR/
│
├── weights/                              # 存放基础模型权重
│   └── sam_vit_h_4b8939.pth              # (必须手动下载) SAM的视觉大模型权重
│
├── models/                               # 存放自动下载的预训练模型
│   ├── grounding-dino-base/              # Grounding DINO权重
│   └── bert-base-chinese/                # BERT中文权重
│
├── raw_waiting_room_photos/              # 【输入】原始待标注图片
│   └── *.jpg                             # 60张候诊室照片
│
├── waiting_room_reasoning_dataset/       # 【输出】自动生成的真实数据集
│   ├── images/                           # 去重后的原图
│   ├── masks/                            # 合并后的像素级Mask
│   └── annotations.json                  # 标注文件
│
├── data/                                 # 玩具数据集(调试用)
│   ├── images/
│   ├── masks/
│   └── annotations.json
│
├── auto_labeling_pipeline.py             # 【核心】自动打标流水线
├── train_real_data.py                    # 【核心】真实数据训练脚本
├── inference_real_data.py                # 【核心】真实数据推理脚本
├── loss.py                               # 损失函数与匈牙利匹配器
├── model.py                              # 早期版本模型定义
├── dataset.py                            # 玩具数据集Dataset类
├── train.py                              # 玩具数据训练脚本
├── inference.py                          # 玩具数据推理脚本
├── generate_toy_data.py                  # 玩具数据生成脚本
├── test_dataloader.py                    # DataLoader调试脚本
│
├── reasoning_rfdetr_best.pth             # 玩具数据训练权重 (~528MB)
├── advanced_rfdetr_waiting_room.pth      # 真实数据训练权重 (~528MB)
├── inference_result.png                  # 玩具数据推理结果
├── inference_real_result.png             # 真实数据推理结果
├── dataloader_test_result.png            # DataLoader测试可视化
│
├── README.MD                             # 部署使用手册
├── PROJECT_GUIDE.md                      # 项目详尽说明
└── 代码介绍.txt                           # 开发历史记录

五、各脚本详细说明与完整代码

5.1 generate_toy_data.py —— 玩具数据生成器

角色: 调试辅助脚本。

功能: 用OpenCV自动生成几何图形作为假图,同时生成对应的白底黑块Mask,并配上中文推理文字。随机生成红、绿、蓝三种颜色的矩形,在图片上画实心矩形,在Mask上同一位置设为白色(代表目标),背景为黑色。为每张图生成对应的中文描述文本,如"画面中那个红色的方形物体"。所有数据保存到 data/ 目录下。

输入: 无,自动随机生成。

输出: data/images/ 目录下的JPG图片、data/masks/ 目录下的PNG掩码、data/annotations.json 标注文件。

运行命令: python generate_toy_data.py

完整代码:

# 自动生成测试数据的脚本
import os
import cv2
import json
import numpy as np


def create_toy_dataset(num_samples=10):
    # 创建目录
    os.makedirs("data/images", exist_ok=True)
    os.makedirs("data/masks", exist_ok=True)

    annotations = []

    print("开始生成 Toy Dataset...")
    for i in range(num_samples):
        # 1. 创建一张黑色背景的图片 (512x512)
        img = np.zeros((512, 512, 3), dtype=np.uint8)
        mask = np.zeros((512, 512), dtype=np.uint8)  # Mask是单通道灰度图

        # 随机生成一个矩形的位置
        x1, y1 = np.random.randint(50, 200, 2)
        w, h = np.random.randint(100, 200, 2)
        x2, y2 = x1 + w, y1 + h

        # 随机决定颜色 (红、绿、蓝) 和对应的文字描述
        color_choice = np.random.choice(['红色', '绿色', '蓝色'])
        color_bgr = (0, 0, 255) if color_choice == '红色' else \
            (0, 255, 0) if color_choice == '绿色' else \
                (255, 0, 0)

        # 在图片上画实心矩形
        cv2.rectangle(img, (x1, y1), (x2, y2), color_bgr, -1)
        # 在Mask上,该区域设为255(白色,代表目标),背景为0(黑色)
        cv2.rectangle(mask, (x1, y1), (x2, y2), 255, -1)

        # 文件路径
        img_name = f"image_{i:03d}.jpg"
        mask_name = f"mask_{i:03d}.png"
        img_path = os.path.join("data/images", img_name)
        mask_path = os.path.join("data/masks", mask_name)

        # 保存图片和Mask
        cv2.imwrite(img_path, img)
        cv2.imwrite(mask_path, mask)

        # 模拟需要推理的文本(实际任务中,这里是MLLM输出的描述)
        reasoning_text = f"画面中那个{color_choice}的方形物体"

        annotations.append({
            "image_path": img_path,
            "mask_path": mask_path,
            "text": reasoning_text
        })

    # 保存 JSON 标注文件
    with open("data/annotations.json", "w", encoding="utf-8") as f:
        json.dump(annotations, f, ensure_ascii=False, indent=4)

    print(f"成功生成 {num_samples} 条数据!保存在 data/ 目录下。")


if __name__ == "__main__":
    create_toy_dataset(20)  # 生成20张图用来测试

5.2 dataset.py —— 玩具数据集Dataset类

角色: 数据管道,被 train.pyinference.pytest_dataloader.py 导入使用。

功能: 定义 ReasoningSegDataset 类,继承自 torch.utils.data.Dataset。读取 JSON 标注文件,对图像进行 Resize到512、ToTensor转换、ImageNet标准归一化(均值[0.485,0.456,0.406],标准差[0.229,0.224,0.225])。对Mask进行 Resize到512(使用NEAREST插值保持二值性)、ToTensor转换、二值化(大于0.5的设为1,其余为0)。每次调用返回图像Tensor、Mask Tensor和文本字符串。

核心类: ReasoningSegDataset

完整代码:

#要把硬盘里的图片和文本喂给显卡,必须经过 torch.utils.data.Dataset。
import json
import torch
from torch.utils.data import Dataset
from PIL import Image
import torchvision.transforms as T


class ReasoningSegDataset(Dataset):
    def __init__(self, json_path, img_size=512):
        """
        json_path: 标注文件的路径
        img_size: 统一缩放的尺寸
        """
        with open(json_path, 'r', encoding='utf-8') as f:
            self.annotations = json.load(f)

        self.img_size = img_size

        # 图片的预处理:缩放 -> 转为Tensor -> 归一化 (ImageNet标准)
        self.img_transform = T.Compose([
            T.Resize((img_size, img_size)),
            T.ToTensor(),
            T.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
        ])

        # Mask的预处理:缩放 -> 转为Tensor (不需要归一化,只有0和1)
        self.mask_transform = T.Compose([
            T.Resize((img_size, img_size), interpolation=T.InterpolationMode.NEAREST),
            T.ToTensor()
        ])

    def __len__(self):
        return len(self.annotations)

    def __getitem__(self, idx):
        anno = self.annotations[idx]

        # 1. 读取原图 (RGB)
        img_path = anno['image_path']
        image = Image.open(img_path).convert('RGB')

        # 2. 读取 Mask (L 表示灰度图)
        mask_path = anno['mask_path']
        mask = Image.open(mask_path).convert('L')

        # 3. 读取文本
        text = anno['text']

        # 4. 应用变换
        image_tensor = self.img_transform(image)
        mask_tensor = self.mask_transform(mask)

        # 确保 mask_tensor 的值是 0 或 1 (原本是0或255,ToTensor会将其转为0-1)
        mask_tensor = (mask_tensor > 0.5).float()

        return image_tensor, mask_tensor, text

5.3 test_dataloader.py —— DataLoader调试脚本

角色: 调试验证脚本。

功能: 实例化Dataset和DataLoader,抓取一个batch的数据,打印图像和Mask的维度信息(期望图像为[2,3,512,512],Mask为[2,1,512,512]),打印文本内容。将第一张图逆归一化后与Mask并排可视化,保存为图片。目的是在进入训练之前确认数据管道完全正确,防止黑盒报错。

输出: dataloader_test_result.png

运行命令: python test_dataloader.py

完整代码:

#把数据抽出来打印一下维度,并且画出来看看。(防止黑盒报错)。
import numpy as np
import torch
from torch.utils.data import DataLoader
from dataset import ReasoningSegDataset
import matplotlib.pyplot as plt


def test_pipeline():
    # 1. 实例化 Dataset
    dataset = ReasoningSegDataset(json_path="data/annotations.json", img_size=512)

    # 2. 实例化 DataLoader (batch_size=2 意味着每次吐出2张图)
    dataloader = DataLoader(dataset, batch_size=2, shuffle=True)

    print(f"数据集总大小: {len(dataset)}")

    # 3. 抓取一个 Batch 的数据
    images, masks, texts = next(iter(dataloader))

    print("\n--- DataLoader 吐出的张量维度 ---")
    print(f"Images Shape: {images.shape}")  # 期望: [2, 3, 512, 512]
    print(f"Masks Shape:  {masks.shape}")  # 期望: [2, 1, 512, 512]
    print(f"Texts:        {texts}")  # 期望: 元组形式的两个字符串

    # 4. 逆归一化并可视化 (选第一张图)
    img_show = images[0].permute(1, 2, 0).numpy()
    # ImageNet 逆归一化公式
    mean = np.array([0.485, 0.456, 0.406])
    std = np.array([0.229, 0.224, 0.225])
    img_show = std * img_show + mean
    img_show = np.clip(img_show, 0, 1)

    mask_show = masks[0].squeeze().numpy()

    # 使用 matplotlib 画图
    plt.figure(figsize=(10, 5))

    plt.subplot(1, 2, 1)
    plt.title("Original Image (Un-normalized)")
    plt.imshow(img_show)
    plt.axis('off')

    plt.subplot(1, 2, 2)
    plt.title("Ground Truth Mask")
    plt.imshow(mask_show, cmap='gray')
    plt.axis('off')

    plt.suptitle(f"Text Prompt: {texts[0]}", fontsize=14)
    plt.tight_layout()

    # 保存可视化结果而不是弹窗(防止服务器环境下报错)
    plt.savefig("dataloader_test_result.png")
    print("\n可视化结果已保存为 dataloader_test_result.png")


if __name__ == "__main__":
    test_pipeline()

5.4 loss.py —— 损失函数与匹配器

角色: 核心算法模块,解决RF-DETR架构中50个无序Query输出与真实标签之间的梯度计算问题。

包含三个核心组件:

组件一:HungarianMatcher(匈牙利匹配器)。 模型输出50个预测结果,但真实标签通常只有1个目标。匹配器首先计算每个预测与真实目标之间的代价矩阵,代价由三部分组成:类别代价(预测类别与真实类别的负对数概率)、BCE代价(预测Mask与真实Mask的二值交叉熵)、Dice代价(预测Mask与真实Mask的Dice系数)。然后使用scipy的linear_sum_assignment算法找出代价最小的那个Query作为正样本匹配,其余49个Query标记为背景。匹配过程在torch.no_grad下进行,不参与梯度计算。

组件二:SetCriterion(集合损失准则)。 基于匹配器的结果计算最终损失。对匹配成功的Query计算分类交叉熵损失(CE Loss)、掩码二值交叉熵损失(BCE Loss)和Dice损失。对未匹配的49个Query计算背景惩罚损失,背景类权重设为0.1以降低惩罚力度,防止网络将所有Query都预测为背景。损失权重比例为 loss_ce 权重1.0、loss_mask 权重5.0、loss_dice 权重2.0。

组件三:辅助损失函数。 dice_loss 专门用于优化分割掩码的边缘质量,计算预测Mask和真实Mask的Dice系数(2倍交集除以并集),损失值为1减去Dice系数。sigmoid_focal_loss 用于处理类别极度不平衡问题,通过降低易分类样本的权重让模型聚焦于难分类样本,该函数已实现但尚未在训练中启用。

完整代码:

import torch
import torch.nn as nn
import torch.nn.functional as F
from scipy.optimize import linear_sum_assignment


def dice_loss(inputs, targets, num_boxes):
    """
    计算 Dice Loss,专门用于优化分割掩码
    inputs: 预测的mask [N, H, W] (未经过sigmoid)
    targets: 真实的mask [N, H, W] (0或1)
    """
    inputs = inputs.sigmoid()
    inputs = inputs.flatten(1)  # 展平为 [N, H*W]
    targets = targets.flatten(1)

    numerator = 2 * (inputs * targets).sum(1)
    denominator = inputs.sum(-1) + targets.sum(-1)
    loss = 1 - (numerator + 1) / (denominator + 1)  # +1 是平滑项,防止除以0
    return loss.sum() / num_boxes


def sigmoid_focal_loss(inputs, targets, num_boxes, alpha: float = 0.25, gamma: float = 2):
    """用于分类的 Focal Loss,解决背景太多导致的正负样本极度不平衡"""
    prob = inputs.sigmoid()
    ce_loss = F.binary_cross_entropy_with_logits(inputs, targets, reduction="none")
    p_t = prob * targets + (1 - prob) * (1 - targets)
    loss = ce_loss * ((1 - p_t) ** gamma)

    if alpha >= 0:
        alpha_t = alpha * targets + (1 - alpha) * (1 - targets)
        loss = alpha_t * loss

    return loss.mean(1).sum() / num_boxes


class HungarianMatcher(nn.Module):
    """
    匈牙利匹配器:用于在 50 个预测结果中,找到与真实 Mask 最匹配的那一个
    """

    def __init__(self, cost_class: float = 1.0, cost_mask: float = 1.0, cost_dice: float = 1.0):
        super().__init__()
        self.cost_class = cost_class
        self.cost_mask = cost_mask
        self.cost_dice = cost_dice

    @torch.no_grad()
    def forward(self, outputs, targets):
        bs, num_queries = outputs["pred_logits"].shape[:2]

        # 获取预测的类别和Mask
        out_prob = outputs["pred_logits"].flatten(0, 1).softmax(-1)  # [batch_size * num_queries, num_classes]
        out_mask = outputs["pred_masks"].flatten(0, 1)  # [batch_size * num_queries, H, W]

        # 获取真实的类别和Mask (这里假定每张图只有1个目标,类别ID为1)
        tgt_ids = torch.cat([v["labels"] for v in targets])  # [num_total_targets]
        tgt_mask = torch.cat([v["masks"] for v in targets])  # [num_total_targets, H, W]

        # 1. 计算类别代价 (Class Cost)
        cost_class = -out_prob[:, tgt_ids]

        # 2. 计算 Mask 代价 (BCE Cost)
        # 为了加速计算,对 Mask 进行下采样降维
        out_mask_down = F.interpolate(out_mask.unsqueeze(1), size=(128, 128), mode="bilinear",
                                      align_corners=False).squeeze(1)
        tgt_mask_down = F.interpolate(tgt_mask.unsqueeze(1).float(), size=(128, 128), mode="nearest").squeeze(1)

        out_mask_flat = out_mask_down.flatten(1)  # [batch_size * num_queries, 128*128]
        tgt_mask_flat = tgt_mask_down.flatten(1)  # [num_total_targets, 128*128]

        # BCE 代价 (使用 pos_weight 处理不平衡)
        cost_mask = F.binary_cross_entropy_with_logits(
            out_mask_flat.unsqueeze(1).expand(-1, tgt_mask_flat.shape[0], -1),
            tgt_mask_flat.unsqueeze(0).expand(out_mask_flat.shape[0], -1, -1),
            reduction="none"
        ).mean(-1)

        # 3. 计算 Dice 代价
        out_mask_sig = out_mask_flat.sigmoid()
        numerator = 2 * torch.matmul(out_mask_sig, tgt_mask_flat.t())
        denominator = out_mask_sig.sum(-1).unsqueeze(1) + tgt_mask_flat.sum(-1).unsqueeze(0)
        cost_dice = 1 - (numerator + 1) / (denominator + 1)

        # 4. 汇总总代价矩阵
        C = self.cost_class * cost_class + self.cost_mask * cost_mask + self.cost_dice * cost_dice
        C = C.view(bs, num_queries, -1).cpu()

        # 5. 使用匈牙利算法找出最优匹配
        sizes = [len(v["masks"]) for v in targets]
        indices = [linear_sum_assignment(c[i]) for i, c in enumerate(C.split(sizes, -1))]

        return [(torch.as_tensor(i, dtype=torch.int64), torch.as_tensor(j, dtype=torch.int64)) for i, j in indices]


class SetCriterion(nn.Module):
    """
    计算最终损失的类
    基于匹配器的结果,计算命中目标的 Loss,和其余背景的 Loss
    """

    def __init__(self, matcher):
        super().__init__()
        self.matcher = matcher
        # 类别0是背景,类别1是前景目标
        empty_weight = torch.ones(2)
        empty_weight[0] = 0.1  # 降低背景类的惩罚权重,防止网络全预测成背景
        self.register_buffer("empty_weight", empty_weight)

    def _get_src_permutation_idx(self, indices):
        batch_idx = torch.cat([torch.full_like(src, i) for i, (src, _) in enumerate(indices)])
        src_idx = torch.cat([src for (src, _) in indices])
        return batch_idx, src_idx

    def forward(self, outputs, targets):
        """
        outputs: 模型的输出字典 (pred_logits, pred_masks)
        targets: 真实标签列表,例如 [{'labels': tensor([1]), 'masks': tensor([1, H, W])}, ...]
        """
        # 1. 获取匹配索引
        indices = self.matcher(outputs, targets)
        idx = self._get_src_permutation_idx(indices)

        # 2. 计算分类损失 (Cross Entropy)
        pred_logits = outputs["pred_logits"]
        target_classes_o = torch.cat([t["labels"][J] for t, (_, J) in zip(targets, indices)])
        target_classes = torch.full(pred_logits.shape[:2], 0, dtype=torch.int64, device=pred_logits.device)  # 默认全是背景(0)
        target_classes[idx] = target_classes_o  # 把匹配到的位置标为前景(1)

        loss_ce = F.cross_entropy(pred_logits.transpose(1, 2), target_classes, self.empty_weight)

        # 3. 计算 Mask 损失 (仅对匹配到的 Query 计算)
        src_masks = outputs["pred_masks"][idx]  # 提取命中的预测 Mask
        target_masks = torch.cat([t["masks"][i] for t, (_, i) in zip(targets, indices)]).to(
            src_masks.device)  # 提取真实的 Mask

        num_boxes = src_masks.shape[0]
        if num_boxes == 0:
            loss_mask = src_masks.sum() * 0  # 防止报错
            loss_dice = src_masks.sum() * 0
        else:
            # 缩放真实Mask到特征图大小,以对齐计算
            target_masks = F.interpolate(target_masks.unsqueeze(1).float(), size=src_masks.shape[-2:],
                                         mode="nearest").squeeze(1)

            loss_mask = F.binary_cross_entropy_with_logits(src_masks, target_masks, reduction="mean")
            loss_dice = dice_loss(src_masks, target_masks, num_boxes)

        losses = {
            "loss_ce": loss_ce,
            "loss_mask": loss_mask,
            "loss_dice": loss_dice
        }
        return losses

5.5 model.py —— 早期版本模型定义

角色: 早期模型版本,用于玩具数据验证。

核心类: RFDETRSegmentation

架构说明: 视觉部分使用ResNet50作为骨干网络,去掉最后两层(全局池化和全连接层),输出2048通道的特征图,通过1x1卷积投影到256维。文本部分使用bert-base-chinese的Tokenizer和Encoder,将文本编码为768维向量,通过线性层投影到256维,然后对所有token的向量做mean-pooling得到全局文本向量。Transformer使用PyTorch内置的nn.Transformer,编码器和解码器各4层,8个注意力头,batch_first模式。Query Embedding是一个可学习的嵌入层,50个Query各256维,加上全局文本向量后作为Transformer解码器的输入。三个预测头:class_head输出2类(背景和前景),bbox_head输出4个坐标值(经过sigmoid归一化到0-1),mask_proj将解码输出投影后与图像特征图做点积生成Mask,最后上采样回原图尺寸。

文本注入方式: BERT全局向量通过mean-pooling后直接加到所有50个Query上,所有Query获得相同的文本信息。

train_real_data.py中AdvancedRFDETR的区别: 无2D空间位置编码,文本注入方式更简单。

完整代码:

import os
import warnings
os.environ["HF_ENDPOINT"] = "https://hf-mirror.com"
os.environ["HF_HUB_DISABLE_SYMLINKS_WARNING"] = "1"
warnings.filterwarnings("ignore")

import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision.models as models
from transformers import AutoTokenizer, AutoModel
import transformers

transformers.logging.set_verbosity_error()

class RFDETRSegmentation(nn.Module):
    def __init__(self, num_queries=50, hidden_dim=256):
        super().__init__()
        self.hidden_dim = hidden_dim
        self.num_queries = num_queries

        # 1. 文本编码器
        self.tokenizer = AutoTokenizer.from_pretrained("bert-base-chinese")
        self.text_encoder = AutoModel.from_pretrained("bert-base-chinese")
        self.text_proj = nn.Linear(768, hidden_dim)

        # 2. 图像骨干网络 (ResNet50)
        resnet = models.resnet50(weights=models.ResNet50_Weights.DEFAULT)
        self.backbone = nn.Sequential(*list(resnet.children())[:-2])
        self.conv_proj = nn.Conv2d(2048, hidden_dim, 1)

        # 3. Transformer
        self.query_embed = nn.Embedding(num_queries, hidden_dim)
        self.transformer = nn.Transformer(
            d_model=hidden_dim, nhead=8, num_encoder_layers=4, num_decoder_layers=4, batch_first=True
        )

        # 4. 预测头
        self.class_head = nn.Linear(hidden_dim, 2)
        self.bbox_head = nn.Sequential(
            nn.Linear(hidden_dim, hidden_dim), nn.ReLU(), nn.Linear(hidden_dim, 4)
        )

        # 5. Mask Head
        self.mask_proj = nn.Sequential(
            nn.Linear(hidden_dim, hidden_dim), nn.ReLU(), nn.Linear(hidden_dim, hidden_dim)
        )

    def forward(self, images, reasoning_texts):
        B = images.shape[0]
        device = images.device

        text_inputs = self.tokenizer(reasoning_texts, return_tensors="pt", padding=True, truncation=True, max_length=64).to(device)
        text_features = self.text_encoder(**text_inputs).last_hidden_state
        text_features = self.text_proj(text_features)
        text_global = text_features.mean(dim=1).unsqueeze(1)

        img_features = self.backbone(images)
        img_features = self.conv_proj(img_features)
        img_seq = img_features.flatten(2).permute(0, 2, 1)

        queries = self.query_embed.weight.unsqueeze(0).repeat(B, 1, 1)
        queries = queries + text_global

        hs = self.transformer(src=img_seq, tgt=queries)

        outputs_class = self.class_head(hs)
        outputs_coord = self.bbox_head(hs).sigmoid()

        mask_embeddings = self.mask_proj(hs)
        outputs_mask = torch.einsum("bnc,bchw->bnhw", mask_embeddings, img_features)
        outputs_mask = F.interpolate(outputs_mask, size=images.shape[-2:], mode="bilinear", align_corners=False)

        return {
            "pred_logits": outputs_class,
            "pred_boxes": outputs_coord,
            "pred_masks": outputs_mask
        }

    def freeze_backbones(self):
        for param in self.backbone.parameters():
            param.requires_grad = False
        for param in self.text_encoder.parameters():
            param.requires_grad = False
        print("已冻结 ResNet50 和 BERT 的参数。")

5.6 train.py —— 玩具数据训练脚本

角色: 玩具数据训练闭环。

功能: 加载 data/annotations.json 玩具数据集,实例化 RFDETRSegmentation 模型,冻结ResNet50和BERT骨干网络以节省显存和加速训练。使用HungarianMatcher和SetCriterion计算损失,AdamW优化器(学习率1e-4),训练20个epoch,batch_size为2,图像尺寸512。每个step打印分类损失、Mask损失和Dice损失的详细数值。训练结束后保存模型权重。

超参数: epochs=20,batch_size=2,lr=1e-4,img_size=512。

损失权重: loss_ce权重1.0,loss_mask权重5.0,loss_dice权重2.0。

输出: reasoning_rfdetr_best.pth

运行命令: python train.py

完整代码:

import torch
from torch.utils.data import DataLoader
import torch.optim as optim

from dataset import ReasoningSegDataset
from model import RFDETRSegmentation
from loss import HungarianMatcher, SetCriterion


def train():
    # 1. 基础设置
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    print(f"开始训练!使用设备: {device}")

    epochs = 20
    batch_size = 2  # 显存如果不够可以改为1

    # 2. 加载数据
    dataset = ReasoningSegDataset(json_path="data/annotations.json", img_size=512)
    dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=True)
    print(f"数据加载完成,共有 {len(dataset)} 个样本。")

    # 3. 初始化模型
    model = RFDETRSegmentation(num_queries=50).to(device)
    model.freeze_backbones()  # 冻结骨干,加速训练
    model.train()

    # 4. 初始化匹配器与损失函数
    matcher = HungarianMatcher()
    criterion = SetCriterion(matcher).to(device)  # 在这里加上 .to(device)


    # 5. 初始化优化器 (AdamW对Transformer最好)
    # 过滤掉被冻结的参数,只优化需要梯度的参数
    optimizer = optim.AdamW(filter(lambda p: p.requires_grad, model.parameters()), lr=1e-4)

    # 6. 开始 Epoch 循环
    for epoch in range(epochs):
        epoch_loss = 0.0

        for step, (images, masks, texts) in enumerate(dataloader):
            images = images.to(device)
            masks = masks.to(device)  # shape: [B, 1, 512, 512]

            # --- A. 组装 Targets 格式 ---
            # SetCriterion 要求的格式是 List[Dict]
            targets = []
            for i in range(images.shape[0]):
                targets.append({
                    "labels": torch.tensor([1], dtype=torch.int64, device=device),  # 前景类别固定为1
                    "masks": masks[i]  # [1, 512, 512]
                })

            # --- B. 前向传播 ---
            outputs = model(images, texts)

            # --- C. 计算损失 ---
            loss_dict = criterion(outputs, targets)

            # 损失加权 (这是DETR官方推荐的权重比例)
            weight_dict = {"loss_ce": 1.0, "loss_mask": 5.0, "loss_dice": 2.0}

            loss = sum(loss_dict[k] * weight_dict[k] for k in loss_dict.keys())

            # --- D. 反向传播与优化 ---
            optimizer.zero_grad()
            loss.backward()

            # 梯度裁剪 (防止Transformer训练初期梯度爆炸)
            torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=0.1)

            optimizer.step()

            epoch_loss += loss.item()

            # 打印每个 Step 的 Loss 细节
            if step % 5 == 0:
                print(f"Epoch [{epoch + 1}/{epochs}] Step [{step}/{len(dataloader)}] "
                      f"Total Loss: {loss.item():.4f} | "
                      f"CE: {loss_dict['loss_ce'].item():.4f}, "
                      f"Mask: {loss_dict['loss_mask'].item():.4f}, "
                      f"Dice: {loss_dict['loss_dice'].item():.4f}")

        # 计算 Epoch 的平均 Loss
        avg_loss = epoch_loss / len(dataloader)
        print(f"Epoch [{epoch + 1}/{epochs}] 结束,平均 Loss: {avg_loss:.4f}\n")

    # 7. 训练结束,保存模型权重
    torch.save(model.state_dict(), "reasoning_rfdetr_best.pth")
    print("训练完成 模型权重已保存为 reasoning_rfdetr_best.pth")


if __name__ == "__main__":
    train()

5.7 inference.py —— 玩具数据推理脚本

角色: 玩具数据推理验证。

功能: 加载训练好的 reasoning_rfdetr_best.pth 权重,取玩具数据集的第0张图进行前向推理。对50个Query的预测logits做softmax得到前景概率,取置信度最高的Query对应的Mask,用sigmoid转换后以0.5为阈值二值化。将原图逆归一化后与预测Mask叠加可视化,左侧显示原图,右侧显示红色半透明Mask叠加结果。

输出: inference_result.png

运行命令: python inference.py

完整代码:

import torch
import json
import matplotlib.pyplot as plt
import torchvision.transforms as T
from PIL import Image
import numpy as np

# 导入我们写的模型
from model import RFDETRSegmentation


def run_inference():
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    print(f"正在使用设备: {device}")

    # 1. 加载标注文件,随便挑一张图来测试 (比如第0张)
    with open("data/annotations.json", "r", encoding="utf-8") as f:
        data = json.load(f)

    test_sample = data[0]  # 测试第一张图
    img_path = test_sample["image_path"]
    text_prompt = test_sample["text"]
    print(f"\n[1] 测试图片: {img_path}")
    print(f"[2] 推理指令: {text_prompt}")

    # 2. 图像预处理
    image = Image.open(img_path).convert("RGB")
    transform = T.Compose([
        T.Resize((512, 512)),
        T.ToTensor(),
        T.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
    ])
    image_tensor = transform(image).unsqueeze(0).to(device)  # 增加 batch 维度

    # 3. 初始化模型并加载权重
    print("[3] 正在加载训练好的模型权重...")
    model = RFDETRSegmentation(num_queries=50).to(device)
    # 加载权重 (严格匹配)
    model.load_state_dict(torch.load("reasoning_rfdetr_best.pth", map_location=device))
    model.eval()  # 切换到推理模式

    # 4. 执行前向推理
    print("[4] 模型思考与分割中...")
    with torch.no_grad():
        outputs = model(image_tensor, [text_prompt])

    # 5. 解析预测结果
    # outputs["pred_logits"] 形状: [1, 50, 2]
    # outputs["pred_masks"] 形状:  [1, 50, 512, 512]

    # 算出 50 个 Query 中,预测为前景(类别1)的概率
    probs = outputs["pred_logits"][0].softmax(dim=-1)
    foreground_scores = probs[:, 1]

    # 找到得分最高的那个 Query
    best_query_idx = foreground_scores.argmax().item()
    best_score = foreground_scores[best_query_idx].item()
    print(f"[5] 命中目标!最高置信度: {best_score:.4f} (来自 Query #{best_query_idx})")

    # 提取对应的 Mask,使用 Sigmoid 将其转换到 0~1,并二值化 (阈值0.5)
    best_mask = outputs["pred_masks"][0, best_query_idx].sigmoid()
    binary_mask = (best_mask > 0.5).cpu().numpy()

    # 6. 结果可视化
    # 还原原图用于展示
    img_show = image_tensor[0].cpu().permute(1, 2, 0).numpy()
    mean = np.array([0.485, 0.456, 0.406])
    std = np.array([0.229, 0.224, 0.225])
    img_show = std * img_show + mean
    img_show = np.clip(img_show, 0, 1)

    plt.figure(figsize=(10, 5))

    # 左边:原图
    plt.subplot(1, 2, 1)
    plt.title("Original Image", fontsize=12)
    plt.imshow(img_show)
    plt.axis("off")

    # 右边:预测的Mask叠加在原图上
    plt.subplot(1, 2, 2)
    plt.title(f"Predicted Mask (Score: {best_score:.2f})", fontsize=12)
    plt.imshow(img_show)
    # 用半透明的红色(Reds)覆盖在原图上
    plt.imshow(binary_mask, cmap='Reds', alpha=0.6)
    plt.axis("off")

    plt.tight_layout()
    plt.savefig("inference_result.png")
    print("\n推理完成 可视化结果已保存为 inference_result.png")


if __name__ == "__main__":
    run_inference()

5.8 auto_labeling_pipeline.py —— 自动打标流水线

角色: 最核心的数据生产引擎,实现零人工标注。

功能: 遍历 raw_waiting_room_photos/ 目录中的所有原始图片,对每张图片和每条推理指令执行三级流水线,最终生成可训练的数据集。

三级流水线详解:

第一级:MLLM推理(step1_mllm_reasoning方法)。 将图片编码为base64格式,连同中文推理指令一起发送给GPT-4o API。系统提示词要求大模型只输出极短的英文名词短语,不能输出完整句子。大模型返回英文实体词后,代码进行防杠精过滤:如果输出包含"None"、“unable”、"cannot"等词,或者输出超过15个词,则判定为无法识别,跳过该条数据。这是保证数据集质量的核心防线。

第二级:Grounding DINO定位(step2_grounding_dino方法)。 将图像和英文实体词送入Grounding DINO模型进行开放世界目标检测。使用box_threshold=0.35和text_threshold=0.25过滤低置信度检测结果。代码包含try-except兼容机制,处理不同版本transformers库的API差异。输出为Bounding Box列表和对应的置信度分数。

第三级:SAM分割(step3_sam_segmentation方法)。 将原图转换为NumPy数组送入SAM的predictor,将DINO输出的Box坐标通过SAM的transform模块转换到模型坐标系,调用predict_torch进行分割。多个物体的Mask合并为一个整体Mask,保存为PNG文件。

最终输出:waiting_room_reasoning_dataset/ 目录下生成 images/ 子目录(去重后的原图)、masks/ 子目录(合并后的Mask)和 annotations.json 标注文件。每条标注包含 image_pathmask_pathtext(中文原指令)、mllm_visual_prompt(MLLM输出的英文实体词)、bboxes(DINO检测到的所有Bounding Box坐标)。

运行命令: python auto_labeling_pipeline.py

注意事项: 运行前需在脚本底部配置API_KEY和BASE_URL,需将原始图片放入 raw_waiting_room_photos/ 目录,需在 reasoning_tasks 列表中定义推理指令。

完整代码:

import os
import cv2
import json
import torch
import base64
import numpy as np
from PIL import Image
from tqdm import tqdm
from openai import OpenAI
from huggingface_hub import snapshot_download
import warnings

warnings.filterwarnings("ignore")

os.environ["HF_ENDPOINT"] = "https://hf-mirror.com"
os.environ["HF_HUB_ENABLE_HF_TRANSFER"] = "0"

from transformers import AutoProcessor, AutoModelForZeroShotObjectDetection
from segment_anything import sam_model_registry, SamPredictor

class AutoLabelingPipeline:
    def __init__(self, api_key, base_url, sam_checkpoint="weights/sam_vit_h_4b8939.pth"):
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        print(f"初始化自动化打标签流水线,使用设备: {self.device}")

        self.client = OpenAI(api_key=api_key, base_url=base_url)
        self.model_name = "gpt-4o"

        print("正在加载 Grounding DINO...")
        local_dino_path = "./models/grounding-dino-base"

        if not os.path.exists(local_dino_path) or not os.listdir(local_dino_path):
            print(f"首次运行,正在从国内镜像下载模型...")
            os.makedirs(local_dino_path, exist_ok=True)
            snapshot_download(
                repo_id="IDEA-Research/grounding-dino-base",
                local_dir=local_dino_path,
                local_dir_use_symlinks=False,
                resume_download=True,
                tqdm_class=tqdm,
                endpoint="https://hf-mirror.com"
            )

        self.dino_processor = AutoProcessor.from_pretrained(local_dino_path)
        self.dino_model = AutoModelForZeroShotObjectDetection.from_pretrained(local_dino_path).to(self.device)

        print("正在加载 SAM (Segment Anything)...")
        sam = sam_model_registry["vit_h"](checkpoint=sam_checkpoint)
        sam.to(device=self.device)
        self.sam_predictor = SamPredictor(sam)

    def encode_image(self, image_path):
        with open(image_path, "rb") as f:
            return base64.b64encode(f.read()).decode('utf-8')

    def step1_mllm_reasoning(self, image_path, reasoning_prompt):
        base64_image = self.encode_image(image_path)

        system_prompt = """
        你是一个冰冷的视觉实体提取机器。用户会给你图和任务。
        【强制规则】:
        1. 只能输出极短的英文名词短语!
        2. 如果你认为图里没有符合的物体,或者无法判断,请严格只输出一个单词:"None"
        3. 绝对不要输出任何完整的句子(如 I am unable to, The image shows 等)。
        正确示例:"blue sofa, wooden table"
        错误示例:"I can see a blue sofa and a wooden table."
        """

        response = self.client.chat.completions.create(
            model=self.model_name,
            messages=[
                {"role": "system", "content": system_prompt},
                {"role": "user", "content": [
                    {"type": "text", "text": f"推理任务:{reasoning_prompt}"},
                    {"type": "image_url", "image_url": {"url": f"data:image/jpeg;base64,{base64_image}"}}
                ]}
            ],
            temperature=0.1
        )
        entities_str = response.choices[0].message.content.strip()
        print(f"MLLM 思考结果: {entities_str}")

        # 防杠精机制:过滤大模型的废话
        lower_str = entities_str.lower()
        if "none" in lower_str or "unable" in lower_str or "cannot" in lower_str or len(entities_str.split()) > 15:
            print("MLLM 无法识别或未按格式输出,跳过该指令。")
            return ""

        return entities_str.replace(",", ".") + "."

    def step2_grounding_dino(self, image_path, text_prompt, box_threshold=0.35, text_threshold=0.25):
        image = Image.open(image_path).convert("RGB")
        inputs = self.dino_processor(images=image, text=text_prompt, return_tensors="pt").to(self.device)

        with torch.no_grad():
            outputs = self.dino_model(**inputs)

        # 修复 Bug:去掉 box_threshold 参数,使用通用写法后手动过滤
        try:
            results = self.dino_processor.post_process_grounded_object_detection(
                outputs,
                inputs.input_ids,
                box_threshold=box_threshold,
                text_threshold=text_threshold,
                target_sizes=[image.size[::-1]]
            )[0]
            boxes = results["boxes"].cpu().numpy()
            scores = results["scores"].cpu().numpy()
            labels = results["labels"]
        except TypeError:
            # 兼容老版本的 transformers 库
            results = self.dino_processor.post_process_grounded_object_detection(
                outputs,
                inputs.input_ids,
                target_sizes=[image.size[::-1]]
            )[0]
            # 手动执行阈值过滤
            keep_idx = results["scores"] > box_threshold
            boxes = results["boxes"][keep_idx].cpu().numpy()
            scores = results["scores"][keep_idx].cpu().numpy()
            labels = [results["labels"][i] for i, keep in enumerate(keep_idx) if keep]

        print(f" DINO 找到 {len(boxes)} 个目标框。")
        return boxes, scores, labels, image

    def step3_sam_segmentation(self, image_rgb, boxes):
        if len(boxes) == 0:
            return []

        image_np = np.array(image_rgb)
        self.sam_predictor.set_image(image_np)

        input_boxes = torch.tensor(boxes, device=self.sam_predictor.device)
        transformed_boxes = self.sam_predictor.transform.apply_boxes_torch(input_boxes, image_np.shape[:2])

        masks, _, _ = self.sam_predictor.predict_torch(
            point_coords=None,
            point_labels=None,
            boxes=transformed_boxes,
            multimask_output=False
        )

        final_masks = masks.squeeze(1).cpu().numpy()
        print(f"SAM 成功生成 {len(final_masks)} 个高精度 Mask。")
        return final_masks

    def process_dataset(self, image_dir, output_dir, reasoning_prompts):
        os.makedirs(output_dir, exist_ok=True)
        os.makedirs(os.path.join(output_dir, "images"), exist_ok=True)
        os.makedirs(os.path.join(output_dir, "masks"), exist_ok=True)

        dataset_annotations = []
        image_files = [f for f in os.listdir(image_dir) if f.lower().endswith(('.png', '.jpg', '.jpeg'))]

        print(f"开始处理数据集,共 {len(image_files)} 张图片...")

        for img_idx, img_name in enumerate(tqdm(image_files)):
            img_path = os.path.join(image_dir, img_name)

            for prompt_idx, prompt in enumerate(reasoning_prompts):
                print(f"\n处理 [{img_name}] - 指令: {prompt}")
                try:
                    dino_prompt = self.step1_mllm_reasoning(img_path, prompt)

                    if not dino_prompt:
                        continue

                    boxes, scores, _, image_rgb = self.step2_grounding_dino(img_path, dino_prompt)

                    if len(boxes) == 0:
                        continue

                    masks = self.step3_sam_segmentation(image_rgb, boxes)

                    if len(masks) == 0:
                        continue

                    new_img_name = f"img_{img_idx:04d}.jpg"
                    new_img_path = os.path.join(output_dir, "images", new_img_name)
                    if not os.path.exists(new_img_path):
                        image_rgb.save(new_img_path)

                    combined_mask = np.zeros(masks[0].shape, dtype=np.uint8)
                    for mask in masks:
                        combined_mask[mask] = 255

                    mask_name = f"mask_{img_idx:04d}_p{prompt_idx}.png"
                    mask_path = os.path.join(output_dir, "masks", mask_name)
                    cv2.imwrite(mask_path, combined_mask)

                    dataset_annotations.append({
                        "image_path": new_img_path,
                        "mask_path": mask_path,
                        "text": prompt,
                        "mllm_visual_prompt": dino_prompt,
                        "bboxes": boxes.tolist()
                    })

                except Exception as e:
                    print(f"处理 {img_name} 时发生错误: {e}")
                    continue

        with open(os.path.join(output_dir, "annotations.json"), "w", encoding="utf-8") as f:
            json.dump(dataset_annotations, f, ensure_ascii=False, indent=4)

        print(f"数据集生成完毕 成功打标 {len(dataset_annotations)} 条数据,保存在 {output_dir}/ 目录下。")


if __name__ == "__main__":
    API_KEY = "your-api-key-here"
    BASE_URL = "your-api-base-url-here"

    pipeline = AutoLabelingPipeline(api_key=API_KEY, base_url=BASE_URL)

    INPUT_IMAGE_DIR = "raw_waiting_room_photos/"
    OUTPUT_DATASET_DIR = "waiting_room_reasoning_dataset/"

    reasoning_tasks = [
        "分割出所有可以移动的软装家具(用来把它们一键消除,看空房间的效果)",
        "找出所有的承重墙和不可拆卸的结构",
        "标出房间内下午会被阳光直射的地板区域(用来决定哪里放怕晒的植物)"
    ]

    pipeline.process_dataset(
        image_dir=INPUT_IMAGE_DIR,
        output_dir=OUTPUT_DATASET_DIR,
        reasoning_prompts=reasoning_tasks
    )

5.9 train_real_data.py —— 真实数据训练脚本

角色: 核心训练脚本,定义终极版模型并执行训练闭环。

核心类: AdvancedRFDETR。相比model.py中的RFDETRSegmentation,增加了2D空间位置编码。

AdvancedRFDETR架构详解:

视觉分支: ResNet50提取特征图,输出尺寸为[B,2048,H/32,W/32],通过1x1卷积投影到256维。然后注入2D空间位置编码:row_embedcol_embed是两个可学习参数,分别编码行位置和列位置信息,拼接后加到图像特征上,让Transformer具备空间方向感。最后将特征图展平为序列[B,256,H*W/1024]作为Transformer编码器的输入。

文本分支: 使用bert-base-chinese的Tokenizer和Encoder,将文本编码后通过线性层投影到256维,对所有token做mean-pooling得到全局文本向量[B,1,256]。

跨模态融合: 50个可学习的Query Embedding加上全局文本向量后作为Transformer解码器的输入。Transformer编码器处理图像序列,解码器通过交叉注意力机制让每个Query关注图像中与文本相关的区域。

预测头: class_head输出[B,50,2]的分类logits,bbox_head输出[B,50,4]的归一化坐标,mask_proj将解码输出投影后与图像特征图做einsum点积生成[B,50,H/32,W/32]的Mask,最后通过双线性插值上采样回[B,50,512,512]。

训练配置: 加载 waiting_room_reasoning_dataset/annotations.json 真实数据集,冻结ResNet50和BERT骨干网络,AdamW优化器(学习率1e-4),训练80个epoch,batch_size为2,图像尺寸512。损失权重为loss_ce权重1.0、loss_mask权重5.0、loss_dice权重2.0。使用梯度裁剪(max_norm=0.1)防止Transformer训练初期的梯度爆炸。

输出: advanced_rfdetr_waiting_room.pth

运行命令: python train_real_data.py

完整代码:

import os
import json
import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision.models as models
from torch.utils.data import Dataset, DataLoader
import torchvision.transforms as T
from PIL import Image
import torch.optim as optim
import warnings
from transformers import AutoTokenizer, AutoModel
from huggingface_hub import snapshot_download

warnings.filterwarnings("ignore")
# 强制使用国内镜像站
os.environ["HF_ENDPOINT"] = "https://hf-mirror.com"
# 注释掉离线模式,允许首次完整下载
# os.environ["TRANSFORMERS_OFFLINE"] = "1"

# 引入之前的 Loss 函数
from loss import HungarianMatcher, SetCriterion


# 1. 适配新数据集的 Dataset
class RealReasoningDataset(Dataset):
    def __init__(self, json_path, img_size=512):
        with open(json_path, 'r', encoding='utf-8') as f:
            self.annotations = json.load(f)
        self.img_size = img_size
        self.img_transform = T.Compose([
            T.Resize((img_size, img_size)),
            T.ToTensor(),
            T.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
        ])
        self.mask_transform = T.Compose([
            T.Resize((img_size, img_size), interpolation=T.InterpolationMode.NEAREST),
            T.ToTensor()
        ])

    def __len__(self):
        return len(self.annotations)

    def __getitem__(self, idx):
        anno = self.annotations[idx]
        image = Image.open(anno['image_path']).convert('RGB')
        mask = Image.open(anno['mask_path']).convert('L')
        text = anno['text']

        image_tensor = self.img_transform(image)
        mask_tensor = self.mask_transform(mask)
        mask_tensor = (mask_tensor > 0.5).float()  # 二值化
        return image_tensor, mask_tensor, text


# 2. RF-DETR (加入 2D 空间位置编码)
class AdvancedRFDETR(nn.Module):
    def __init__(self, num_queries=50, hidden_dim=256):
        super().__init__()
        self.hidden_dim = hidden_dim
        self.num_queries = num_queries

        # 本地化加载 BERT
        local_bert_path = "./models/bert-base-chinese"
        if not os.path.exists(local_bert_path) or not os.listdir(local_bert_path):
            print("从国内镜像补全 BERT 模型 ...")
            os.makedirs(local_bert_path, exist_ok=True)
            snapshot_download(
                repo_id="bert-base-chinese",
                local_dir=local_bert_path,
                local_dir_use_symlinks=False,
                resume_download=True,
                endpoint="https://hf-mirror.com"
            )
            print(" BERT 下载完成!")

        print(f"从本地加载 BERT: {local_bert_path}")
        self.tokenizer = AutoTokenizer.from_pretrained(local_bert_path)
        self.text_encoder = AutoModel.from_pretrained(local_bert_path)
        self.text_proj = nn.Linear(768, hidden_dim)


        # 图像特征
        resnet = models.resnet50(weights=models.ResNet50_Weights.DEFAULT)
        self.backbone = nn.Sequential(*list(resnet.children())[:-2])
        self.conv_proj = nn.Conv2d(2048, hidden_dim, 1)

        # 二维相对位置编码
        self.row_embed = nn.Parameter(torch.rand(50, hidden_dim // 2))
        self.col_embed = nn.Parameter(torch.rand(50, hidden_dim // 2))

        # Transformer
        self.query_embed = nn.Embedding(num_queries, hidden_dim)
        self.transformer = nn.Transformer(
            d_model=hidden_dim, nhead=8, num_encoder_layers=4, num_decoder_layers=4, batch_first=True
        )

        # 预测头
        self.class_head = nn.Linear(hidden_dim, 2)
        self.bbox_head = nn.Sequential(nn.Linear(hidden_dim, hidden_dim), nn.ReLU(), nn.Linear(hidden_dim, 4))
        self.mask_proj = nn.Sequential(nn.Linear(hidden_dim, hidden_dim), nn.ReLU(), nn.Linear(hidden_dim, hidden_dim))

    def forward(self, images, reasoning_texts):
        B = images.shape[0]
        device = images.device

        text_inputs = self.tokenizer(reasoning_texts, return_tensors="pt", padding=True, truncation=True,
                                     max_length=64).to(device)
        text_features = self.text_proj(self.text_encoder(**text_inputs).last_hidden_state)
        text_global = text_features.mean(dim=1).unsqueeze(1)

        img_features = self.backbone(images)
        H, W = img_features.shape[-2:]
        img_features = self.conv_proj(img_features)

        pos_x = self.col_embed[:W].unsqueeze(0).repeat(H, 1, 1)
        pos_y = self.row_embed[:H].unsqueeze(1).repeat(1, W, 1)
        pos_encoding = torch.cat([pos_y, pos_x], dim=-1).permute(2, 0, 1).unsqueeze(0).repeat(B, 1, 1, 1)

        img_seq = (img_features + pos_encoding).flatten(2).permute(0, 2, 1)

        queries = self.query_embed.weight.unsqueeze(0).repeat(B, 1, 1) + text_global
        hs = self.transformer(src=img_seq, tgt=queries)

        outputs_class = self.class_head(hs)
        outputs_coord = self.bbox_head(hs).sigmoid()

        mask_embeddings = self.mask_proj(hs)
        outputs_mask = torch.einsum("bnc,bchw->bnhw", mask_embeddings, img_features)
        outputs_mask = F.interpolate(outputs_mask, size=images.shape[-2:], mode="bilinear", align_corners=False)

        return {"pred_logits": outputs_class, "pred_boxes": outputs_coord, "pred_masks": outputs_mask}

    def freeze_backbones(self):
        for param in self.backbone.parameters(): param.requires_grad = False
        for param in self.text_encoder.parameters(): param.requires_grad = False


# 3. 真实数据训练闭环
def train_real_data():
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    print(f"开始在真实数据集上训练!使用设备: {device}")

    json_path = "waiting_room_reasoning_dataset/annotations.json"

    dataset = RealReasoningDataset(json_path=json_path, img_size=512)
    dataloader = DataLoader(dataset, batch_size=2, shuffle=True)
    print(f"数据加载完成,共有 {len(dataset)} 个真实样本。")

    model = AdvancedRFDETR(num_queries=50).to(device)
    model.freeze_backbones()
    model.train()

    matcher = HungarianMatcher()
    criterion = SetCriterion(matcher).to(device)
    optimizer = optim.AdamW(filter(lambda p: p.requires_grad, model.parameters()), lr=1e-4)

    epochs = 80
    for epoch in range(epochs):
        epoch_loss = 0.0
        for step, (images, masks, texts) in enumerate(dataloader):
            images, masks = images.to(device), masks.to(device)

            targets = [{"labels": torch.tensor([1], dtype=torch.int64, device=device), "masks": masks[i]} for i in
                       range(images.shape[0])]
            outputs = model(images, texts)

            loss_dict = criterion(outputs, targets)
            loss = loss_dict['loss_ce'] * 1.0 + loss_dict['loss_mask'] * 5.0 + loss_dict['loss_dice'] * 2.0

            optimizer.zero_grad()
            loss.backward()
            torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=0.1)
            optimizer.step()
            epoch_loss += loss.item()

        print(f"Epoch [{epoch + 1}/{epochs}] 平均 Loss: {epoch_loss / len(dataloader):.4f}")

    torch.save(model.state_dict(), "advanced_rfdetr_waiting_room.pth")
    print("真实场景模型训练完成 权重已保存。")


if __name__ == "__main__":
    train_real_data()

5.10 inference_real_data.py —— 真实数据推理脚本

角色: 最终验收推理脚本。

功能: 加载训练好的 advanced_rfdetr_waiting_room.pth 权重,取真实数据集的第13张图(可修改索引)进行前向推理。对50个Query的预测logits做softmax得到前景概率,取置信度最高的Query对应的Mask,用sigmoid转换后以0.5为阈值二值化。将原图逆归一化后与预测Mask叠加可视化,使用jet伪彩色映射让分割区域更加醒目。

输出: inference_real_result.png

运行命令: python inference_real_data.py

完整代码:

import os
import torch
import json
import matplotlib.pyplot as plt
import torchvision.transforms as T
from PIL import Image
import numpy as np
import warnings

warnings.filterwarnings("ignore")

# 直接从训练脚本中导入带有空间位置编码的终极版模型
from train_real_data import AdvancedRFDETR


def run_real_inference():
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    print(f"正在使用设备: {device}")

    # 1. 加载我们自动生成的标注文件,挑一张图来测试 (比如第 0 张)
    dataset_dir = "waiting_room_reasoning_dataset"
    with open(os.path.join(dataset_dir, "annotations.json"), "r", encoding="utf-8") as f:
        data = json.load(f)

    # 修改这里的索引 (0 到 12),看看不同图片的预测效果
    test_sample = data[13]
    img_path = test_sample["image_path"]
    # 既可以使用原来的中文指令,也可以试试 DINO 提取的英文特征
    text_prompt = test_sample["text"]

    print(f"\n[1] 测试图片: {img_path}")
    print(f"[2] 推理指令: {text_prompt}")

    # 2. 图像预处理
    image = Image.open(img_path).convert("RGB")
    transform = T.Compose([
        T.Resize((512, 512)),
        T.ToTensor(),
        T.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
    ])
    image_tensor = transform(image).unsqueeze(0).to(device)

    # 3. 初始化模型并加载权重
    print("[3] 正在加载候诊室专属模型权重...")
    model = AdvancedRFDETR(num_queries=50).to(device)
    model.load_state_dict(torch.load("advanced_rfdetr_waiting_room.pth", map_location=device))
    model.eval()  # 切换到推理模式

    # 4. 执行前向推理
    print("[4] 模型思考与分割中...")
    with torch.no_grad():
        outputs = model(image_tensor, [text_prompt])

    # 5. 解析预测结果
    probs = outputs["pred_logits"][0].softmax(dim=-1)
    foreground_scores = probs[:, 1]

    best_query_idx = foreground_scores.argmax().item()
    best_score = foreground_scores[best_query_idx].item()
    print(f"[5] 命中目标!最高置信度: {best_score:.4f} (来自 Query #{best_query_idx})")

    # 提取对应的 Mask
    best_mask = outputs["pred_masks"][0, best_query_idx].sigmoid()
    binary_mask = (best_mask > 0.5).cpu().numpy()

    # 6. 结果可视化
    img_show = image_tensor[0].cpu().permute(1, 2, 0).numpy()
    mean = np.array([0.485, 0.456, 0.406])
    std = np.array([0.229, 0.224, 0.225])
    img_show = std * img_show + mean
    img_show = np.clip(img_show, 0, 1)

    # 画图
    plt.figure(figsize=(12, 6))

    # 原图
    plt.subplot(1, 2, 1)
    plt.title("Original Waiting Room", fontsize=12)
    plt.imshow(img_show)
    plt.axis("off")

    # 预测的 Mask 叠加
    plt.subplot(1, 2, 2)
    plt.title(f"Predicted Mask (Score: {best_score:.2f})", fontsize=12)
    plt.imshow(img_show)
    plt.imshow(binary_mask, cmap='jet', alpha=0.5)  # 使用 jet 伪彩色,看起来更清晰
    plt.axis("off")

    plt.tight_layout()
    plt.savefig("inference_real_result.png", bbox_inches='tight', dpi=150)
    print("\n推理完成 可视化结果已保存为 inference_real_result.png")


if __name__ == "__main__":
    run_real_inference()

六、执行顺序

整个项目的执行分为两个阶段,建议先跑通玩具数据验证链,再跑真实数据生产链。

阶段零:环境准备(一次性操作)

安装Python依赖包,包括torch、torchvision、transformers、opencv-python、pillow、openai、tqdm、matplotlib、huggingface_hub、scipy。安装Segment Anything官方库。手动下载sam_vit_h_4b8939.pth权重文件放入weights/目录。首次运行脚本时会自动从国内镜像下载grounding-dino-base和bert-base-chinese到models/目录。

阶段A:玩具数据验证链(快速验证流水线是否跑通)

第一步: 运行 python generate_toy_data.py,生成20张几何图形假图和对应的Mask,保存在data/目录。

第二步: dataset.py被其他脚本导入使用,无需单独执行。

第三步: 运行 python test_dataloader.py,验证DataLoader输出的维度是否正确,可视化图片和Mask,确认数据管道无误。

第四步: 运行 python train.py,在玩具数据上训练20个epoch,生成reasoning_rfdetr_best.pth权重文件。

第五步: 运行 python inference.py,加载权重对玩具数据进行推理,生成inference_result.png可视化结果。至此玩具数据闭环完成,证明数据管道、模型、损失函数、推理全链路跑通。

阶段B:真实数据生产链

第六步: 运行 python auto_labeling_pipeline.py。运行前需在脚本底部配置API_KEY和BASE_URL,将原始图片放入raw_waiting_room_photos/目录,在reasoning_tasks列表中定义推理指令。脚本遍历所有原始图片,对每张图执行MLLM推理、DINO定位、SAM分割三级流水线,生成waiting_room_reasoning_dataset/数据集。这一步最耗时,60张图乘以3条指令最多180次API调用。

第七步: 运行 python train_real_data.py,在真实数据上训练80个epoch,生成advanced_rfdetr_waiting_room.pth权重文件。

第八步: 运行 python inference_real_data.py,加载权重对真实数据进行推理,生成inference_real_result.png可视化结果。至此真实数据闭环完成。

脚本间的导入依赖关系

  • train.py 导入 dataset.py(ReasoningSegDataset类)、model.py(RFDETRSegmentation类)、loss.py(HungarianMatcher和SetCriterion类)
  • test_dataloader.py 导入 dataset.py
  • inference.py 导入 model.py
  • train_real_data.py 导入 loss.py,并在文件内部定义了RealReasoningDataset类和AdvancedRFDETR类
  • inference_real_data.pytrain_real_data.py 导入 AdvancedRFDETR类
  • auto_labeling_pipeline.pygenerate_toy_data.py 独立运行,无项目内部依赖

七、环境配置

7.1 硬件要求

项目 最低要求 推荐配置
GPU NVIDIA GPU, 8GB显存 12GB+显存
CUDA 11.8+ 12.0+
内存 16GB 32GB
硬盘 10GB(含模型权重) 20GB

7.2 软件依赖安装

第一步,安装PyTorch(CUDA 11.8版本):

pip install torch torchvision torchaudio --index-url https://download.pytorch.org/whl/cu118

第二步,安装基础依赖(使用清华镜像加速):

pip install transformers opencv-python pillow openai tqdm matplotlib huggingface_hub scipy -i https://pypi.tuna.tsinghua.edu.cn/simple

第三步,安装Segment Anything官方库:

pip install git+https://github.com/facebookresearch/segment-anything.git

7.3 模型权重下载

权重 大小 下载方式
sam_vit_h_4b8939.pth ~2.4GB 手动下载weights/
grounding-dino-base ~700MB 首次运行自动下载(国内镜像)
bert-base-chinese ~400MB 首次运行自动下载(国内镜像)

7.4 API配置

auto_labeling_pipeline.py 脚本底部修改 API_KEYBASE_URL 变量。API_KEY填写你的大模型API密钥,BASE_URL填写兼容OpenAI格式的API地址。


八、快速开始

最简流程(跑通真实数据只需四步)

第一步: 将你的原始图片放入 raw_waiting_room_photos/ 目录。

第二步: 编辑 auto_labeling_pipeline.py 底部的 API_KEYBASE_URL,然后运行:

python auto_labeling_pipeline.py

第三步: 运行训练:

python train_real_data.py

第四步: 运行推理查看效果:

python inference_real_data.py

调试流程(用玩具数据快速验证)

依次运行:

python generate_toy_data.py
python test_dataloader.py
python train.py
python inference.py

九、核心算法解析

9.1 匈牙利匹配

DETR架构的核心挑战是模型输出50个Query预测,但每张图通常只有1个真实目标,需要确定哪个预测对应真实目标来计算损失。

解决方法是构建代价矩阵并使用匈牙利算法找最优匹配。代价矩阵由三部分组成:类别代价(预测类别概率的负对数)、BCE代价(预测Mask与真实Mask的二值交叉熵)、Dice代价(1减去Dice系数)。三个代价加权求和得到总代价矩阵,然后使用scipy的linear_sum_assignment算法找出代价最小的Query作为正样本匹配。匹配成功的Query计算分类损失和分割损失,其余49个Query标记为背景并计算背景惩罚损失(背景权重降低到0.1)。

9.2 Grounding DINO与SAM的分工

Grounding DINO和SAM虽然都涉及图像中的目标定位,但分工完全不同,是串联互补关系而非冗余。

维度 Grounding DINO SAM
输出 矩形Bounding Box 像素级Mask
精度 目标级(粗粒度) 像素级(细粒度)
输入 图像 + 文本 图像 + Box/Point提示
角色 “在哪里”(语义定位) “精确形状”(像素分割)
能否单独使用 能,但只有粗糙矩形框 不能,需要提示告诉它切哪里

为什么两个都要用: 只用DINO只能得到粗糙的矩形框,无法做像素级分割。只用SAM则不知道要分割哪个物体,需要DINO提供Box提示。二者串联后,DINO负责开放世界理解和粗定位,SAM负责闭集精细分割,形成完整的"理解→定位→分割"流水线。

9.3 2D空间位置编码

Transformer本身不具备空间位置感知能力,需要显式注入位置信息。AdvancedRFDETR使用可学习的2D位置编码:row_embed编码行位置(50个位置各128维),col_embed编码列位置(50个位置各128维)。对于特征图上的每个位置,将其行编码和列编码拼接得到256维的位置向量,加到图像特征上。这样Transformer就能区分"左上角"和"右下角"的特征,具备空间方向感。

9.4 损失函数组合

训练使用三种损失的加权组合:

损失 权重 作用
loss_ce (交叉熵) 1.0 分类:前景 vs 背景
loss_mask (BCE) 5.0 像素级二分类
loss_dice (Dice) 2.0 优化Mask边缘重叠率

十、效果展示

文件 说明
dataloader_test_result.png DataLoader输出的图像和Mask并排可视化,用于验证数据管道正确性
inference_result.png 玩具数据推理结果,展示模型对几何图形的分割效果
inference_real_result.png 真实候诊室场景推理结果,展示模型在实际场景中的分割效果

十一、已知问题与改进方向

11.1 当前主要问题

问题一(最高优先级):训练数据量不足。 当前数据集约180条标注数据,而DETR类模型通常需要数千到数万样本才能充分收敛。数据量少导致模型容易过拟合训练集,泛化能力差。

问题二(最高优先级):Focal Loss已实现但未启用。 sigmoid_focal_loss函数已在loss.py中完整实现,但SetCriterion中仍使用普通的交叉熵损失。50个Query中只有1个是前景,类别极度不平衡,Focal Loss能有效缓解此问题。

问题三(高优先级):文本注入方式过于简单。 当前将BERT的全局向量通过mean-pooling后直接加到所有50个Query上,所有Query获得完全相同的文本信息,丧失了Query多样性。更好的做法是使用Cross-Attention让每个Query自主关注文本的不同部分。

问题四(高优先级):多目标场景强制单标签。 当前代码假设每张图只有1个目标(labels固定为[1]),但实际数据中一张图可能包含多个物体(沙发、茶几、椅子、绿植等)。合并后的Mask包含多个不连通区域,匈牙利匹配只分配1个Query来拟合这个复杂组合,非常困难。

问题五(中优先级):位置编码使用随机初始化。 row_embedcol_embed使用torch.rand均匀随机初始化,对于位置编码来说远不如正弦编码或学习的正弦初始化稳定,加上数据量少,这些参数难以收敛到有意义的值。

问题六(中优先级):Mask分辨率低。 图像特征图来自ResNet50最后一层,stride为32,512输入对应16x16的特征图。虽然通过双线性插值上采样回原尺寸,但细节已经丢失,小物体在特征图上可能只占1到2个像素。

问题七(中优先级):标签链路误差累积。 训练标签经过GPT-4o到Grounding DINO到SAM三级传递,每一环都可能引入误差(MLLM幻觉、DINO漏检误检、SAM边界噪声),这些误差会固化进训练标签中。

11.2 改进方向

  • 数据扩充: 增加原始图片数量,目标将训练样本从180条提升到500条以上。添加数据增强模块,包括图像翻转、色彩扰动、随机裁剪,以及文本同义改写。
  • 模型改进: 将文本注入方式从简单的mean-pooling加法改为Cross-Attention机制。引入FPN多尺度特征融合,让Mask Head能利用高分辨率特征。将位置编码从随机初始化改为正弦编码初始化。
  • 损失优化: 在SetCriterion中启用已实现的sigmoid_focal_loss替换普通交叉熵损失。修改targets构建逻辑,支持一张图多个独立Mask的多目标匹配。
  • 训练监控: 划分训练集和验证集,每个epoch计算验证集上的mIoU指标。添加TensorBoard日志记录,跟踪损失曲线和指标变化,防止过拟合。

十二、技术栈

类别 技术
深度学习框架 PyTorch, torchvision
预训练模型 ResNet50, BERT (bert-base-chinese)
开放世界检测 Grounding DINO (transformers)
图像分割 Segment Anything (SAM, vit_h)
多模态大模型 GPT-4o (OpenAI API)
优化算法 AdamW, 匈牙利算法 (scipy)
数据处理 OpenCV, PIL, NumPy
可视化 Matplotlib

License

本项目仅供学习和研究使用。

致谢

Logo

一站式 AI 云服务平台

更多推荐