Reasoning RF-DETR:基于多模态推理的开放世界图像分割系统(附完整代码)
Reasoning RF-DETR:基于多模态推理的开放世界图像分割系统(附完整代码)
一个结合 MLLM(多模态大语言模型)、Grounding DINO 和 SAM 自动构建数据集,并训练 RF-DETR 端到端分割模型的完整闭环项目。
目录
- 一、项目简介
- 二、核心创新点
- 三、系统架构总览
- 四、项目目录结构
- 五、各脚本详细说明与完整代码
- 六、执行顺序
- 七、环境配置
- 八、快速开始
- 九、核心算法解析
- 十、效果展示
- 十一、已知问题与改进方向
- 十二、技术栈
一、项目简介
1.1 这是什么项目
本项目实现了一条**“零代码全自动打标 → 端到端带空间感知的分割模型训练 → 工业级推理”**的完整闭环流水线。
用户只需输入一句自然语言指令(如"分割出所有可以移动的软装家具"),系统就能完成以下工作:调用大模型理解指令语义,自动在图像中定位目标,生成像素级分割掩码,并训练一个轻量级端到端模型来完成同样的任务。
注意:当前模型为简易测试版本,用于快速验证流水线可行性。 模型采用基础的 ResNet50 + BERT + Transformer 架构,文本注入方式为简单的 mean-pooling 加法,Mask Head 使用单尺度特征图点积。该版本在少量数据上可跑通全流程,但分割精度和泛化能力有限。如需用于实际场景,建议参考第十一节中的改进方向进行优化。
1.2 解决什么问题
传统图像分割模型需要大量人工标注的像素级Mask,标注成本极高。本项目通过 MLLM + Grounding DINO + SAM 三级流水线实现零人工标注,大幅降低数据构建成本。
1.3 应用场景
- 候诊室和办公室软装识别与一键消除
- 工地危险物检测与分割
- 桥梁和机械结构缺陷定位
- 室内设计辅助(光照区域分析、家具布局规划)
二、核心创新点
创新点一:三级自动打标流水线。 MLLM负责语义理解,Grounding DINO负责开放世界定位,SAM负责像素级分割,三者串联实现零人工标注。
创新点二:推理式文本引导分割。 不是简单的"分割猫",而是接受复杂推理指令,例如"找出下午会被阳光直射的地板区域"。
创新点三:轻量级RF-DETR架构。 使用ResNet50作为视觉骨干、BERT作为文本编码器、Transformer进行跨模态融合,冻结骨干网络后单卡即可训练。
创新点四:2D空间位置编码。 注入行列位置信息,让Transformer具备空间方向感。
创新点五:匈牙利二分图匹配。 解决50个Query无序输出与真实标签之间的梯度分配问题。
三、系统架构总览
整个系统分为两个解耦的阶段。
阶段一:数据自动构建(离线预处理)
原始图片和中文推理指令首先送入MLLM API(GPT-4o),大模型理解指令语义后输出英文实体词,例如输入"分割出所有可以移动的软装家具",输出"chairs, coffee table, plant"。这些英文实体词送入Grounding DINO,它在图像上预测出多个Bounding Box,将沙发、茶几、绿植等目标用矩形框锁定。这些Box坐标送入SAM,SAM沿着物体的实际边缘切出像素级的高清Mask。最后,脚本自动将原图、中文原指令、合并后的高清Mask、所有Bbox坐标打包写入JSON文件,形成可训练的数据集。
阶段二:端到端模型训练与推理
AdvancedRFDETR模型接收图像和文本指令作为输入。ResNet50提取视觉特征并注入2D空间位置编码,BERT提取文本语义特征并与Transformer的Object Queries融合,Transformer编解码器计算跨模态注意力,Mask Head将解码输出与图像特征图进行点积生成最终Mask。训练时使用匈牙利匹配器找到与真实标签最匹配的Query,对该Query计算分类损失和分割损失,对其余Query计算背景惩罚损失。推理时取置信度最高的Query对应的Mask作为最终输出。
四、项目目录结构
Reasoning_RFDETR/
│
├── weights/ # 存放基础模型权重
│ └── sam_vit_h_4b8939.pth # (必须手动下载) SAM的视觉大模型权重
│
├── models/ # 存放自动下载的预训练模型
│ ├── grounding-dino-base/ # Grounding DINO权重
│ └── bert-base-chinese/ # BERT中文权重
│
├── raw_waiting_room_photos/ # 【输入】原始待标注图片
│ └── *.jpg # 60张候诊室照片
│
├── waiting_room_reasoning_dataset/ # 【输出】自动生成的真实数据集
│ ├── images/ # 去重后的原图
│ ├── masks/ # 合并后的像素级Mask
│ └── annotations.json # 标注文件
│
├── data/ # 玩具数据集(调试用)
│ ├── images/
│ ├── masks/
│ └── annotations.json
│
├── auto_labeling_pipeline.py # 【核心】自动打标流水线
├── train_real_data.py # 【核心】真实数据训练脚本
├── inference_real_data.py # 【核心】真实数据推理脚本
├── loss.py # 损失函数与匈牙利匹配器
├── model.py # 早期版本模型定义
├── dataset.py # 玩具数据集Dataset类
├── train.py # 玩具数据训练脚本
├── inference.py # 玩具数据推理脚本
├── generate_toy_data.py # 玩具数据生成脚本
├── test_dataloader.py # DataLoader调试脚本
│
├── reasoning_rfdetr_best.pth # 玩具数据训练权重 (~528MB)
├── advanced_rfdetr_waiting_room.pth # 真实数据训练权重 (~528MB)
├── inference_result.png # 玩具数据推理结果
├── inference_real_result.png # 真实数据推理结果
├── dataloader_test_result.png # DataLoader测试可视化
│
├── README.MD # 部署使用手册
├── PROJECT_GUIDE.md # 项目详尽说明
└── 代码介绍.txt # 开发历史记录
五、各脚本详细说明与完整代码
5.1 generate_toy_data.py —— 玩具数据生成器
角色: 调试辅助脚本。
功能: 用OpenCV自动生成几何图形作为假图,同时生成对应的白底黑块Mask,并配上中文推理文字。随机生成红、绿、蓝三种颜色的矩形,在图片上画实心矩形,在Mask上同一位置设为白色(代表目标),背景为黑色。为每张图生成对应的中文描述文本,如"画面中那个红色的方形物体"。所有数据保存到 data/ 目录下。
输入: 无,自动随机生成。
输出: data/images/ 目录下的JPG图片、data/masks/ 目录下的PNG掩码、data/annotations.json 标注文件。
运行命令: python generate_toy_data.py
完整代码:
# 自动生成测试数据的脚本
import os
import cv2
import json
import numpy as np
def create_toy_dataset(num_samples=10):
# 创建目录
os.makedirs("data/images", exist_ok=True)
os.makedirs("data/masks", exist_ok=True)
annotations = []
print("开始生成 Toy Dataset...")
for i in range(num_samples):
# 1. 创建一张黑色背景的图片 (512x512)
img = np.zeros((512, 512, 3), dtype=np.uint8)
mask = np.zeros((512, 512), dtype=np.uint8) # Mask是单通道灰度图
# 随机生成一个矩形的位置
x1, y1 = np.random.randint(50, 200, 2)
w, h = np.random.randint(100, 200, 2)
x2, y2 = x1 + w, y1 + h
# 随机决定颜色 (红、绿、蓝) 和对应的文字描述
color_choice = np.random.choice(['红色', '绿色', '蓝色'])
color_bgr = (0, 0, 255) if color_choice == '红色' else \
(0, 255, 0) if color_choice == '绿色' else \
(255, 0, 0)
# 在图片上画实心矩形
cv2.rectangle(img, (x1, y1), (x2, y2), color_bgr, -1)
# 在Mask上,该区域设为255(白色,代表目标),背景为0(黑色)
cv2.rectangle(mask, (x1, y1), (x2, y2), 255, -1)
# 文件路径
img_name = f"image_{i:03d}.jpg"
mask_name = f"mask_{i:03d}.png"
img_path = os.path.join("data/images", img_name)
mask_path = os.path.join("data/masks", mask_name)
# 保存图片和Mask
cv2.imwrite(img_path, img)
cv2.imwrite(mask_path, mask)
# 模拟需要推理的文本(实际任务中,这里是MLLM输出的描述)
reasoning_text = f"画面中那个{color_choice}的方形物体"
annotations.append({
"image_path": img_path,
"mask_path": mask_path,
"text": reasoning_text
})
# 保存 JSON 标注文件
with open("data/annotations.json", "w", encoding="utf-8") as f:
json.dump(annotations, f, ensure_ascii=False, indent=4)
print(f"成功生成 {num_samples} 条数据!保存在 data/ 目录下。")
if __name__ == "__main__":
create_toy_dataset(20) # 生成20张图用来测试
5.2 dataset.py —— 玩具数据集Dataset类
角色: 数据管道,被 train.py、inference.py 和 test_dataloader.py 导入使用。
功能: 定义 ReasoningSegDataset 类,继承自 torch.utils.data.Dataset。读取 JSON 标注文件,对图像进行 Resize到512、ToTensor转换、ImageNet标准归一化(均值[0.485,0.456,0.406],标准差[0.229,0.224,0.225])。对Mask进行 Resize到512(使用NEAREST插值保持二值性)、ToTensor转换、二值化(大于0.5的设为1,其余为0)。每次调用返回图像Tensor、Mask Tensor和文本字符串。
核心类: ReasoningSegDataset
完整代码:
#要把硬盘里的图片和文本喂给显卡,必须经过 torch.utils.data.Dataset。
import json
import torch
from torch.utils.data import Dataset
from PIL import Image
import torchvision.transforms as T
class ReasoningSegDataset(Dataset):
def __init__(self, json_path, img_size=512):
"""
json_path: 标注文件的路径
img_size: 统一缩放的尺寸
"""
with open(json_path, 'r', encoding='utf-8') as f:
self.annotations = json.load(f)
self.img_size = img_size
# 图片的预处理:缩放 -> 转为Tensor -> 归一化 (ImageNet标准)
self.img_transform = T.Compose([
T.Resize((img_size, img_size)),
T.ToTensor(),
T.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
])
# Mask的预处理:缩放 -> 转为Tensor (不需要归一化,只有0和1)
self.mask_transform = T.Compose([
T.Resize((img_size, img_size), interpolation=T.InterpolationMode.NEAREST),
T.ToTensor()
])
def __len__(self):
return len(self.annotations)
def __getitem__(self, idx):
anno = self.annotations[idx]
# 1. 读取原图 (RGB)
img_path = anno['image_path']
image = Image.open(img_path).convert('RGB')
# 2. 读取 Mask (L 表示灰度图)
mask_path = anno['mask_path']
mask = Image.open(mask_path).convert('L')
# 3. 读取文本
text = anno['text']
# 4. 应用变换
image_tensor = self.img_transform(image)
mask_tensor = self.mask_transform(mask)
# 确保 mask_tensor 的值是 0 或 1 (原本是0或255,ToTensor会将其转为0-1)
mask_tensor = (mask_tensor > 0.5).float()
return image_tensor, mask_tensor, text
5.3 test_dataloader.py —— DataLoader调试脚本
角色: 调试验证脚本。
功能: 实例化Dataset和DataLoader,抓取一个batch的数据,打印图像和Mask的维度信息(期望图像为[2,3,512,512],Mask为[2,1,512,512]),打印文本内容。将第一张图逆归一化后与Mask并排可视化,保存为图片。目的是在进入训练之前确认数据管道完全正确,防止黑盒报错。
输出: dataloader_test_result.png
运行命令: python test_dataloader.py
完整代码:
#把数据抽出来打印一下维度,并且画出来看看。(防止黑盒报错)。
import numpy as np
import torch
from torch.utils.data import DataLoader
from dataset import ReasoningSegDataset
import matplotlib.pyplot as plt
def test_pipeline():
# 1. 实例化 Dataset
dataset = ReasoningSegDataset(json_path="data/annotations.json", img_size=512)
# 2. 实例化 DataLoader (batch_size=2 意味着每次吐出2张图)
dataloader = DataLoader(dataset, batch_size=2, shuffle=True)
print(f"数据集总大小: {len(dataset)}")
# 3. 抓取一个 Batch 的数据
images, masks, texts = next(iter(dataloader))
print("\n--- DataLoader 吐出的张量维度 ---")
print(f"Images Shape: {images.shape}") # 期望: [2, 3, 512, 512]
print(f"Masks Shape: {masks.shape}") # 期望: [2, 1, 512, 512]
print(f"Texts: {texts}") # 期望: 元组形式的两个字符串
# 4. 逆归一化并可视化 (选第一张图)
img_show = images[0].permute(1, 2, 0).numpy()
# ImageNet 逆归一化公式
mean = np.array([0.485, 0.456, 0.406])
std = np.array([0.229, 0.224, 0.225])
img_show = std * img_show + mean
img_show = np.clip(img_show, 0, 1)
mask_show = masks[0].squeeze().numpy()
# 使用 matplotlib 画图
plt.figure(figsize=(10, 5))
plt.subplot(1, 2, 1)
plt.title("Original Image (Un-normalized)")
plt.imshow(img_show)
plt.axis('off')
plt.subplot(1, 2, 2)
plt.title("Ground Truth Mask")
plt.imshow(mask_show, cmap='gray')
plt.axis('off')
plt.suptitle(f"Text Prompt: {texts[0]}", fontsize=14)
plt.tight_layout()
# 保存可视化结果而不是弹窗(防止服务器环境下报错)
plt.savefig("dataloader_test_result.png")
print("\n可视化结果已保存为 dataloader_test_result.png")
if __name__ == "__main__":
test_pipeline()
5.4 loss.py —— 损失函数与匹配器
角色: 核心算法模块,解决RF-DETR架构中50个无序Query输出与真实标签之间的梯度计算问题。
包含三个核心组件:
组件一:HungarianMatcher(匈牙利匹配器)。 模型输出50个预测结果,但真实标签通常只有1个目标。匹配器首先计算每个预测与真实目标之间的代价矩阵,代价由三部分组成:类别代价(预测类别与真实类别的负对数概率)、BCE代价(预测Mask与真实Mask的二值交叉熵)、Dice代价(预测Mask与真实Mask的Dice系数)。然后使用scipy的linear_sum_assignment算法找出代价最小的那个Query作为正样本匹配,其余49个Query标记为背景。匹配过程在torch.no_grad下进行,不参与梯度计算。
组件二:SetCriterion(集合损失准则)。 基于匹配器的结果计算最终损失。对匹配成功的Query计算分类交叉熵损失(CE Loss)、掩码二值交叉熵损失(BCE Loss)和Dice损失。对未匹配的49个Query计算背景惩罚损失,背景类权重设为0.1以降低惩罚力度,防止网络将所有Query都预测为背景。损失权重比例为 loss_ce 权重1.0、loss_mask 权重5.0、loss_dice 权重2.0。
组件三:辅助损失函数。 dice_loss 专门用于优化分割掩码的边缘质量,计算预测Mask和真实Mask的Dice系数(2倍交集除以并集),损失值为1减去Dice系数。sigmoid_focal_loss 用于处理类别极度不平衡问题,通过降低易分类样本的权重让模型聚焦于难分类样本,该函数已实现但尚未在训练中启用。
完整代码:
import torch
import torch.nn as nn
import torch.nn.functional as F
from scipy.optimize import linear_sum_assignment
def dice_loss(inputs, targets, num_boxes):
"""
计算 Dice Loss,专门用于优化分割掩码
inputs: 预测的mask [N, H, W] (未经过sigmoid)
targets: 真实的mask [N, H, W] (0或1)
"""
inputs = inputs.sigmoid()
inputs = inputs.flatten(1) # 展平为 [N, H*W]
targets = targets.flatten(1)
numerator = 2 * (inputs * targets).sum(1)
denominator = inputs.sum(-1) + targets.sum(-1)
loss = 1 - (numerator + 1) / (denominator + 1) # +1 是平滑项,防止除以0
return loss.sum() / num_boxes
def sigmoid_focal_loss(inputs, targets, num_boxes, alpha: float = 0.25, gamma: float = 2):
"""用于分类的 Focal Loss,解决背景太多导致的正负样本极度不平衡"""
prob = inputs.sigmoid()
ce_loss = F.binary_cross_entropy_with_logits(inputs, targets, reduction="none")
p_t = prob * targets + (1 - prob) * (1 - targets)
loss = ce_loss * ((1 - p_t) ** gamma)
if alpha >= 0:
alpha_t = alpha * targets + (1 - alpha) * (1 - targets)
loss = alpha_t * loss
return loss.mean(1).sum() / num_boxes
class HungarianMatcher(nn.Module):
"""
匈牙利匹配器:用于在 50 个预测结果中,找到与真实 Mask 最匹配的那一个
"""
def __init__(self, cost_class: float = 1.0, cost_mask: float = 1.0, cost_dice: float = 1.0):
super().__init__()
self.cost_class = cost_class
self.cost_mask = cost_mask
self.cost_dice = cost_dice
@torch.no_grad()
def forward(self, outputs, targets):
bs, num_queries = outputs["pred_logits"].shape[:2]
# 获取预测的类别和Mask
out_prob = outputs["pred_logits"].flatten(0, 1).softmax(-1) # [batch_size * num_queries, num_classes]
out_mask = outputs["pred_masks"].flatten(0, 1) # [batch_size * num_queries, H, W]
# 获取真实的类别和Mask (这里假定每张图只有1个目标,类别ID为1)
tgt_ids = torch.cat([v["labels"] for v in targets]) # [num_total_targets]
tgt_mask = torch.cat([v["masks"] for v in targets]) # [num_total_targets, H, W]
# 1. 计算类别代价 (Class Cost)
cost_class = -out_prob[:, tgt_ids]
# 2. 计算 Mask 代价 (BCE Cost)
# 为了加速计算,对 Mask 进行下采样降维
out_mask_down = F.interpolate(out_mask.unsqueeze(1), size=(128, 128), mode="bilinear",
align_corners=False).squeeze(1)
tgt_mask_down = F.interpolate(tgt_mask.unsqueeze(1).float(), size=(128, 128), mode="nearest").squeeze(1)
out_mask_flat = out_mask_down.flatten(1) # [batch_size * num_queries, 128*128]
tgt_mask_flat = tgt_mask_down.flatten(1) # [num_total_targets, 128*128]
# BCE 代价 (使用 pos_weight 处理不平衡)
cost_mask = F.binary_cross_entropy_with_logits(
out_mask_flat.unsqueeze(1).expand(-1, tgt_mask_flat.shape[0], -1),
tgt_mask_flat.unsqueeze(0).expand(out_mask_flat.shape[0], -1, -1),
reduction="none"
).mean(-1)
# 3. 计算 Dice 代价
out_mask_sig = out_mask_flat.sigmoid()
numerator = 2 * torch.matmul(out_mask_sig, tgt_mask_flat.t())
denominator = out_mask_sig.sum(-1).unsqueeze(1) + tgt_mask_flat.sum(-1).unsqueeze(0)
cost_dice = 1 - (numerator + 1) / (denominator + 1)
# 4. 汇总总代价矩阵
C = self.cost_class * cost_class + self.cost_mask * cost_mask + self.cost_dice * cost_dice
C = C.view(bs, num_queries, -1).cpu()
# 5. 使用匈牙利算法找出最优匹配
sizes = [len(v["masks"]) for v in targets]
indices = [linear_sum_assignment(c[i]) for i, c in enumerate(C.split(sizes, -1))]
return [(torch.as_tensor(i, dtype=torch.int64), torch.as_tensor(j, dtype=torch.int64)) for i, j in indices]
class SetCriterion(nn.Module):
"""
计算最终损失的类
基于匹配器的结果,计算命中目标的 Loss,和其余背景的 Loss
"""
def __init__(self, matcher):
super().__init__()
self.matcher = matcher
# 类别0是背景,类别1是前景目标
empty_weight = torch.ones(2)
empty_weight[0] = 0.1 # 降低背景类的惩罚权重,防止网络全预测成背景
self.register_buffer("empty_weight", empty_weight)
def _get_src_permutation_idx(self, indices):
batch_idx = torch.cat([torch.full_like(src, i) for i, (src, _) in enumerate(indices)])
src_idx = torch.cat([src for (src, _) in indices])
return batch_idx, src_idx
def forward(self, outputs, targets):
"""
outputs: 模型的输出字典 (pred_logits, pred_masks)
targets: 真实标签列表,例如 [{'labels': tensor([1]), 'masks': tensor([1, H, W])}, ...]
"""
# 1. 获取匹配索引
indices = self.matcher(outputs, targets)
idx = self._get_src_permutation_idx(indices)
# 2. 计算分类损失 (Cross Entropy)
pred_logits = outputs["pred_logits"]
target_classes_o = torch.cat([t["labels"][J] for t, (_, J) in zip(targets, indices)])
target_classes = torch.full(pred_logits.shape[:2], 0, dtype=torch.int64, device=pred_logits.device) # 默认全是背景(0)
target_classes[idx] = target_classes_o # 把匹配到的位置标为前景(1)
loss_ce = F.cross_entropy(pred_logits.transpose(1, 2), target_classes, self.empty_weight)
# 3. 计算 Mask 损失 (仅对匹配到的 Query 计算)
src_masks = outputs["pred_masks"][idx] # 提取命中的预测 Mask
target_masks = torch.cat([t["masks"][i] for t, (_, i) in zip(targets, indices)]).to(
src_masks.device) # 提取真实的 Mask
num_boxes = src_masks.shape[0]
if num_boxes == 0:
loss_mask = src_masks.sum() * 0 # 防止报错
loss_dice = src_masks.sum() * 0
else:
# 缩放真实Mask到特征图大小,以对齐计算
target_masks = F.interpolate(target_masks.unsqueeze(1).float(), size=src_masks.shape[-2:],
mode="nearest").squeeze(1)
loss_mask = F.binary_cross_entropy_with_logits(src_masks, target_masks, reduction="mean")
loss_dice = dice_loss(src_masks, target_masks, num_boxes)
losses = {
"loss_ce": loss_ce,
"loss_mask": loss_mask,
"loss_dice": loss_dice
}
return losses
5.5 model.py —— 早期版本模型定义
角色: 早期模型版本,用于玩具数据验证。
核心类: RFDETRSegmentation
架构说明: 视觉部分使用ResNet50作为骨干网络,去掉最后两层(全局池化和全连接层),输出2048通道的特征图,通过1x1卷积投影到256维。文本部分使用bert-base-chinese的Tokenizer和Encoder,将文本编码为768维向量,通过线性层投影到256维,然后对所有token的向量做mean-pooling得到全局文本向量。Transformer使用PyTorch内置的nn.Transformer,编码器和解码器各4层,8个注意力头,batch_first模式。Query Embedding是一个可学习的嵌入层,50个Query各256维,加上全局文本向量后作为Transformer解码器的输入。三个预测头:class_head输出2类(背景和前景),bbox_head输出4个坐标值(经过sigmoid归一化到0-1),mask_proj将解码输出投影后与图像特征图做点积生成Mask,最后上采样回原图尺寸。
文本注入方式: BERT全局向量通过mean-pooling后直接加到所有50个Query上,所有Query获得相同的文本信息。
与train_real_data.py中AdvancedRFDETR的区别: 无2D空间位置编码,文本注入方式更简单。
完整代码:
import os
import warnings
os.environ["HF_ENDPOINT"] = "https://hf-mirror.com"
os.environ["HF_HUB_DISABLE_SYMLINKS_WARNING"] = "1"
warnings.filterwarnings("ignore")
import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision.models as models
from transformers import AutoTokenizer, AutoModel
import transformers
transformers.logging.set_verbosity_error()
class RFDETRSegmentation(nn.Module):
def __init__(self, num_queries=50, hidden_dim=256):
super().__init__()
self.hidden_dim = hidden_dim
self.num_queries = num_queries
# 1. 文本编码器
self.tokenizer = AutoTokenizer.from_pretrained("bert-base-chinese")
self.text_encoder = AutoModel.from_pretrained("bert-base-chinese")
self.text_proj = nn.Linear(768, hidden_dim)
# 2. 图像骨干网络 (ResNet50)
resnet = models.resnet50(weights=models.ResNet50_Weights.DEFAULT)
self.backbone = nn.Sequential(*list(resnet.children())[:-2])
self.conv_proj = nn.Conv2d(2048, hidden_dim, 1)
# 3. Transformer
self.query_embed = nn.Embedding(num_queries, hidden_dim)
self.transformer = nn.Transformer(
d_model=hidden_dim, nhead=8, num_encoder_layers=4, num_decoder_layers=4, batch_first=True
)
# 4. 预测头
self.class_head = nn.Linear(hidden_dim, 2)
self.bbox_head = nn.Sequential(
nn.Linear(hidden_dim, hidden_dim), nn.ReLU(), nn.Linear(hidden_dim, 4)
)
# 5. Mask Head
self.mask_proj = nn.Sequential(
nn.Linear(hidden_dim, hidden_dim), nn.ReLU(), nn.Linear(hidden_dim, hidden_dim)
)
def forward(self, images, reasoning_texts):
B = images.shape[0]
device = images.device
text_inputs = self.tokenizer(reasoning_texts, return_tensors="pt", padding=True, truncation=True, max_length=64).to(device)
text_features = self.text_encoder(**text_inputs).last_hidden_state
text_features = self.text_proj(text_features)
text_global = text_features.mean(dim=1).unsqueeze(1)
img_features = self.backbone(images)
img_features = self.conv_proj(img_features)
img_seq = img_features.flatten(2).permute(0, 2, 1)
queries = self.query_embed.weight.unsqueeze(0).repeat(B, 1, 1)
queries = queries + text_global
hs = self.transformer(src=img_seq, tgt=queries)
outputs_class = self.class_head(hs)
outputs_coord = self.bbox_head(hs).sigmoid()
mask_embeddings = self.mask_proj(hs)
outputs_mask = torch.einsum("bnc,bchw->bnhw", mask_embeddings, img_features)
outputs_mask = F.interpolate(outputs_mask, size=images.shape[-2:], mode="bilinear", align_corners=False)
return {
"pred_logits": outputs_class,
"pred_boxes": outputs_coord,
"pred_masks": outputs_mask
}
def freeze_backbones(self):
for param in self.backbone.parameters():
param.requires_grad = False
for param in self.text_encoder.parameters():
param.requires_grad = False
print("已冻结 ResNet50 和 BERT 的参数。")
5.6 train.py —— 玩具数据训练脚本
角色: 玩具数据训练闭环。
功能: 加载 data/annotations.json 玩具数据集,实例化 RFDETRSegmentation 模型,冻结ResNet50和BERT骨干网络以节省显存和加速训练。使用HungarianMatcher和SetCriterion计算损失,AdamW优化器(学习率1e-4),训练20个epoch,batch_size为2,图像尺寸512。每个step打印分类损失、Mask损失和Dice损失的详细数值。训练结束后保存模型权重。
超参数: epochs=20,batch_size=2,lr=1e-4,img_size=512。
损失权重: loss_ce权重1.0,loss_mask权重5.0,loss_dice权重2.0。
输出: reasoning_rfdetr_best.pth
运行命令: python train.py
完整代码:
import torch
from torch.utils.data import DataLoader
import torch.optim as optim
from dataset import ReasoningSegDataset
from model import RFDETRSegmentation
from loss import HungarianMatcher, SetCriterion
def train():
# 1. 基础设置
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"开始训练!使用设备: {device}")
epochs = 20
batch_size = 2 # 显存如果不够可以改为1
# 2. 加载数据
dataset = ReasoningSegDataset(json_path="data/annotations.json", img_size=512)
dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=True)
print(f"数据加载完成,共有 {len(dataset)} 个样本。")
# 3. 初始化模型
model = RFDETRSegmentation(num_queries=50).to(device)
model.freeze_backbones() # 冻结骨干,加速训练
model.train()
# 4. 初始化匹配器与损失函数
matcher = HungarianMatcher()
criterion = SetCriterion(matcher).to(device) # 在这里加上 .to(device)
# 5. 初始化优化器 (AdamW对Transformer最好)
# 过滤掉被冻结的参数,只优化需要梯度的参数
optimizer = optim.AdamW(filter(lambda p: p.requires_grad, model.parameters()), lr=1e-4)
# 6. 开始 Epoch 循环
for epoch in range(epochs):
epoch_loss = 0.0
for step, (images, masks, texts) in enumerate(dataloader):
images = images.to(device)
masks = masks.to(device) # shape: [B, 1, 512, 512]
# --- A. 组装 Targets 格式 ---
# SetCriterion 要求的格式是 List[Dict]
targets = []
for i in range(images.shape[0]):
targets.append({
"labels": torch.tensor([1], dtype=torch.int64, device=device), # 前景类别固定为1
"masks": masks[i] # [1, 512, 512]
})
# --- B. 前向传播 ---
outputs = model(images, texts)
# --- C. 计算损失 ---
loss_dict = criterion(outputs, targets)
# 损失加权 (这是DETR官方推荐的权重比例)
weight_dict = {"loss_ce": 1.0, "loss_mask": 5.0, "loss_dice": 2.0}
loss = sum(loss_dict[k] * weight_dict[k] for k in loss_dict.keys())
# --- D. 反向传播与优化 ---
optimizer.zero_grad()
loss.backward()
# 梯度裁剪 (防止Transformer训练初期梯度爆炸)
torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=0.1)
optimizer.step()
epoch_loss += loss.item()
# 打印每个 Step 的 Loss 细节
if step % 5 == 0:
print(f"Epoch [{epoch + 1}/{epochs}] Step [{step}/{len(dataloader)}] "
f"Total Loss: {loss.item():.4f} | "
f"CE: {loss_dict['loss_ce'].item():.4f}, "
f"Mask: {loss_dict['loss_mask'].item():.4f}, "
f"Dice: {loss_dict['loss_dice'].item():.4f}")
# 计算 Epoch 的平均 Loss
avg_loss = epoch_loss / len(dataloader)
print(f"Epoch [{epoch + 1}/{epochs}] 结束,平均 Loss: {avg_loss:.4f}\n")
# 7. 训练结束,保存模型权重
torch.save(model.state_dict(), "reasoning_rfdetr_best.pth")
print("训练完成 模型权重已保存为 reasoning_rfdetr_best.pth")
if __name__ == "__main__":
train()
5.7 inference.py —— 玩具数据推理脚本
角色: 玩具数据推理验证。
功能: 加载训练好的 reasoning_rfdetr_best.pth 权重,取玩具数据集的第0张图进行前向推理。对50个Query的预测logits做softmax得到前景概率,取置信度最高的Query对应的Mask,用sigmoid转换后以0.5为阈值二值化。将原图逆归一化后与预测Mask叠加可视化,左侧显示原图,右侧显示红色半透明Mask叠加结果。
输出: inference_result.png
运行命令: python inference.py
完整代码:
import torch
import json
import matplotlib.pyplot as plt
import torchvision.transforms as T
from PIL import Image
import numpy as np
# 导入我们写的模型
from model import RFDETRSegmentation
def run_inference():
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"正在使用设备: {device}")
# 1. 加载标注文件,随便挑一张图来测试 (比如第0张)
with open("data/annotations.json", "r", encoding="utf-8") as f:
data = json.load(f)
test_sample = data[0] # 测试第一张图
img_path = test_sample["image_path"]
text_prompt = test_sample["text"]
print(f"\n[1] 测试图片: {img_path}")
print(f"[2] 推理指令: {text_prompt}")
# 2. 图像预处理
image = Image.open(img_path).convert("RGB")
transform = T.Compose([
T.Resize((512, 512)),
T.ToTensor(),
T.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
])
image_tensor = transform(image).unsqueeze(0).to(device) # 增加 batch 维度
# 3. 初始化模型并加载权重
print("[3] 正在加载训练好的模型权重...")
model = RFDETRSegmentation(num_queries=50).to(device)
# 加载权重 (严格匹配)
model.load_state_dict(torch.load("reasoning_rfdetr_best.pth", map_location=device))
model.eval() # 切换到推理模式
# 4. 执行前向推理
print("[4] 模型思考与分割中...")
with torch.no_grad():
outputs = model(image_tensor, [text_prompt])
# 5. 解析预测结果
# outputs["pred_logits"] 形状: [1, 50, 2]
# outputs["pred_masks"] 形状: [1, 50, 512, 512]
# 算出 50 个 Query 中,预测为前景(类别1)的概率
probs = outputs["pred_logits"][0].softmax(dim=-1)
foreground_scores = probs[:, 1]
# 找到得分最高的那个 Query
best_query_idx = foreground_scores.argmax().item()
best_score = foreground_scores[best_query_idx].item()
print(f"[5] 命中目标!最高置信度: {best_score:.4f} (来自 Query #{best_query_idx})")
# 提取对应的 Mask,使用 Sigmoid 将其转换到 0~1,并二值化 (阈值0.5)
best_mask = outputs["pred_masks"][0, best_query_idx].sigmoid()
binary_mask = (best_mask > 0.5).cpu().numpy()
# 6. 结果可视化
# 还原原图用于展示
img_show = image_tensor[0].cpu().permute(1, 2, 0).numpy()
mean = np.array([0.485, 0.456, 0.406])
std = np.array([0.229, 0.224, 0.225])
img_show = std * img_show + mean
img_show = np.clip(img_show, 0, 1)
plt.figure(figsize=(10, 5))
# 左边:原图
plt.subplot(1, 2, 1)
plt.title("Original Image", fontsize=12)
plt.imshow(img_show)
plt.axis("off")
# 右边:预测的Mask叠加在原图上
plt.subplot(1, 2, 2)
plt.title(f"Predicted Mask (Score: {best_score:.2f})", fontsize=12)
plt.imshow(img_show)
# 用半透明的红色(Reds)覆盖在原图上
plt.imshow(binary_mask, cmap='Reds', alpha=0.6)
plt.axis("off")
plt.tight_layout()
plt.savefig("inference_result.png")
print("\n推理完成 可视化结果已保存为 inference_result.png")
if __name__ == "__main__":
run_inference()
5.8 auto_labeling_pipeline.py —— 自动打标流水线
角色: 最核心的数据生产引擎,实现零人工标注。
功能: 遍历 raw_waiting_room_photos/ 目录中的所有原始图片,对每张图片和每条推理指令执行三级流水线,最终生成可训练的数据集。
三级流水线详解:
第一级:MLLM推理(step1_mllm_reasoning方法)。 将图片编码为base64格式,连同中文推理指令一起发送给GPT-4o API。系统提示词要求大模型只输出极短的英文名词短语,不能输出完整句子。大模型返回英文实体词后,代码进行防杠精过滤:如果输出包含"None"、“unable”、"cannot"等词,或者输出超过15个词,则判定为无法识别,跳过该条数据。这是保证数据集质量的核心防线。
第二级:Grounding DINO定位(step2_grounding_dino方法)。 将图像和英文实体词送入Grounding DINO模型进行开放世界目标检测。使用box_threshold=0.35和text_threshold=0.25过滤低置信度检测结果。代码包含try-except兼容机制,处理不同版本transformers库的API差异。输出为Bounding Box列表和对应的置信度分数。
第三级:SAM分割(step3_sam_segmentation方法)。 将原图转换为NumPy数组送入SAM的predictor,将DINO输出的Box坐标通过SAM的transform模块转换到模型坐标系,调用predict_torch进行分割。多个物体的Mask合并为一个整体Mask,保存为PNG文件。
最终输出: 在 waiting_room_reasoning_dataset/ 目录下生成 images/ 子目录(去重后的原图)、masks/ 子目录(合并后的Mask)和 annotations.json 标注文件。每条标注包含 image_path、mask_path、text(中文原指令)、mllm_visual_prompt(MLLM输出的英文实体词)、bboxes(DINO检测到的所有Bounding Box坐标)。
运行命令: python auto_labeling_pipeline.py
注意事项: 运行前需在脚本底部配置API_KEY和BASE_URL,需将原始图片放入 raw_waiting_room_photos/ 目录,需在 reasoning_tasks 列表中定义推理指令。
完整代码:
import os
import cv2
import json
import torch
import base64
import numpy as np
from PIL import Image
from tqdm import tqdm
from openai import OpenAI
from huggingface_hub import snapshot_download
import warnings
warnings.filterwarnings("ignore")
os.environ["HF_ENDPOINT"] = "https://hf-mirror.com"
os.environ["HF_HUB_ENABLE_HF_TRANSFER"] = "0"
from transformers import AutoProcessor, AutoModelForZeroShotObjectDetection
from segment_anything import sam_model_registry, SamPredictor
class AutoLabelingPipeline:
def __init__(self, api_key, base_url, sam_checkpoint="weights/sam_vit_h_4b8939.pth"):
self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"初始化自动化打标签流水线,使用设备: {self.device}")
self.client = OpenAI(api_key=api_key, base_url=base_url)
self.model_name = "gpt-4o"
print("正在加载 Grounding DINO...")
local_dino_path = "./models/grounding-dino-base"
if not os.path.exists(local_dino_path) or not os.listdir(local_dino_path):
print(f"首次运行,正在从国内镜像下载模型...")
os.makedirs(local_dino_path, exist_ok=True)
snapshot_download(
repo_id="IDEA-Research/grounding-dino-base",
local_dir=local_dino_path,
local_dir_use_symlinks=False,
resume_download=True,
tqdm_class=tqdm,
endpoint="https://hf-mirror.com"
)
self.dino_processor = AutoProcessor.from_pretrained(local_dino_path)
self.dino_model = AutoModelForZeroShotObjectDetection.from_pretrained(local_dino_path).to(self.device)
print("正在加载 SAM (Segment Anything)...")
sam = sam_model_registry["vit_h"](checkpoint=sam_checkpoint)
sam.to(device=self.device)
self.sam_predictor = SamPredictor(sam)
def encode_image(self, image_path):
with open(image_path, "rb") as f:
return base64.b64encode(f.read()).decode('utf-8')
def step1_mllm_reasoning(self, image_path, reasoning_prompt):
base64_image = self.encode_image(image_path)
system_prompt = """
你是一个冰冷的视觉实体提取机器。用户会给你图和任务。
【强制规则】:
1. 只能输出极短的英文名词短语!
2. 如果你认为图里没有符合的物体,或者无法判断,请严格只输出一个单词:"None"
3. 绝对不要输出任何完整的句子(如 I am unable to, The image shows 等)。
正确示例:"blue sofa, wooden table"
错误示例:"I can see a blue sofa and a wooden table."
"""
response = self.client.chat.completions.create(
model=self.model_name,
messages=[
{"role": "system", "content": system_prompt},
{"role": "user", "content": [
{"type": "text", "text": f"推理任务:{reasoning_prompt}"},
{"type": "image_url", "image_url": {"url": f"data:image/jpeg;base64,{base64_image}"}}
]}
],
temperature=0.1
)
entities_str = response.choices[0].message.content.strip()
print(f"MLLM 思考结果: {entities_str}")
# 防杠精机制:过滤大模型的废话
lower_str = entities_str.lower()
if "none" in lower_str or "unable" in lower_str or "cannot" in lower_str or len(entities_str.split()) > 15:
print("MLLM 无法识别或未按格式输出,跳过该指令。")
return ""
return entities_str.replace(",", ".") + "."
def step2_grounding_dino(self, image_path, text_prompt, box_threshold=0.35, text_threshold=0.25):
image = Image.open(image_path).convert("RGB")
inputs = self.dino_processor(images=image, text=text_prompt, return_tensors="pt").to(self.device)
with torch.no_grad():
outputs = self.dino_model(**inputs)
# 修复 Bug:去掉 box_threshold 参数,使用通用写法后手动过滤
try:
results = self.dino_processor.post_process_grounded_object_detection(
outputs,
inputs.input_ids,
box_threshold=box_threshold,
text_threshold=text_threshold,
target_sizes=[image.size[::-1]]
)[0]
boxes = results["boxes"].cpu().numpy()
scores = results["scores"].cpu().numpy()
labels = results["labels"]
except TypeError:
# 兼容老版本的 transformers 库
results = self.dino_processor.post_process_grounded_object_detection(
outputs,
inputs.input_ids,
target_sizes=[image.size[::-1]]
)[0]
# 手动执行阈值过滤
keep_idx = results["scores"] > box_threshold
boxes = results["boxes"][keep_idx].cpu().numpy()
scores = results["scores"][keep_idx].cpu().numpy()
labels = [results["labels"][i] for i, keep in enumerate(keep_idx) if keep]
print(f" DINO 找到 {len(boxes)} 个目标框。")
return boxes, scores, labels, image
def step3_sam_segmentation(self, image_rgb, boxes):
if len(boxes) == 0:
return []
image_np = np.array(image_rgb)
self.sam_predictor.set_image(image_np)
input_boxes = torch.tensor(boxes, device=self.sam_predictor.device)
transformed_boxes = self.sam_predictor.transform.apply_boxes_torch(input_boxes, image_np.shape[:2])
masks, _, _ = self.sam_predictor.predict_torch(
point_coords=None,
point_labels=None,
boxes=transformed_boxes,
multimask_output=False
)
final_masks = masks.squeeze(1).cpu().numpy()
print(f"SAM 成功生成 {len(final_masks)} 个高精度 Mask。")
return final_masks
def process_dataset(self, image_dir, output_dir, reasoning_prompts):
os.makedirs(output_dir, exist_ok=True)
os.makedirs(os.path.join(output_dir, "images"), exist_ok=True)
os.makedirs(os.path.join(output_dir, "masks"), exist_ok=True)
dataset_annotations = []
image_files = [f for f in os.listdir(image_dir) if f.lower().endswith(('.png', '.jpg', '.jpeg'))]
print(f"开始处理数据集,共 {len(image_files)} 张图片...")
for img_idx, img_name in enumerate(tqdm(image_files)):
img_path = os.path.join(image_dir, img_name)
for prompt_idx, prompt in enumerate(reasoning_prompts):
print(f"\n处理 [{img_name}] - 指令: {prompt}")
try:
dino_prompt = self.step1_mllm_reasoning(img_path, prompt)
if not dino_prompt:
continue
boxes, scores, _, image_rgb = self.step2_grounding_dino(img_path, dino_prompt)
if len(boxes) == 0:
continue
masks = self.step3_sam_segmentation(image_rgb, boxes)
if len(masks) == 0:
continue
new_img_name = f"img_{img_idx:04d}.jpg"
new_img_path = os.path.join(output_dir, "images", new_img_name)
if not os.path.exists(new_img_path):
image_rgb.save(new_img_path)
combined_mask = np.zeros(masks[0].shape, dtype=np.uint8)
for mask in masks:
combined_mask[mask] = 255
mask_name = f"mask_{img_idx:04d}_p{prompt_idx}.png"
mask_path = os.path.join(output_dir, "masks", mask_name)
cv2.imwrite(mask_path, combined_mask)
dataset_annotations.append({
"image_path": new_img_path,
"mask_path": mask_path,
"text": prompt,
"mllm_visual_prompt": dino_prompt,
"bboxes": boxes.tolist()
})
except Exception as e:
print(f"处理 {img_name} 时发生错误: {e}")
continue
with open(os.path.join(output_dir, "annotations.json"), "w", encoding="utf-8") as f:
json.dump(dataset_annotations, f, ensure_ascii=False, indent=4)
print(f"数据集生成完毕 成功打标 {len(dataset_annotations)} 条数据,保存在 {output_dir}/ 目录下。")
if __name__ == "__main__":
API_KEY = "your-api-key-here"
BASE_URL = "your-api-base-url-here"
pipeline = AutoLabelingPipeline(api_key=API_KEY, base_url=BASE_URL)
INPUT_IMAGE_DIR = "raw_waiting_room_photos/"
OUTPUT_DATASET_DIR = "waiting_room_reasoning_dataset/"
reasoning_tasks = [
"分割出所有可以移动的软装家具(用来把它们一键消除,看空房间的效果)",
"找出所有的承重墙和不可拆卸的结构",
"标出房间内下午会被阳光直射的地板区域(用来决定哪里放怕晒的植物)"
]
pipeline.process_dataset(
image_dir=INPUT_IMAGE_DIR,
output_dir=OUTPUT_DATASET_DIR,
reasoning_prompts=reasoning_tasks
)
5.9 train_real_data.py —— 真实数据训练脚本
角色: 核心训练脚本,定义终极版模型并执行训练闭环。
核心类: AdvancedRFDETR。相比model.py中的RFDETRSegmentation,增加了2D空间位置编码。
AdvancedRFDETR架构详解:
视觉分支: ResNet50提取特征图,输出尺寸为[B,2048,H/32,W/32],通过1x1卷积投影到256维。然后注入2D空间位置编码:row_embed和col_embed是两个可学习参数,分别编码行位置和列位置信息,拼接后加到图像特征上,让Transformer具备空间方向感。最后将特征图展平为序列[B,256,H*W/1024]作为Transformer编码器的输入。
文本分支: 使用bert-base-chinese的Tokenizer和Encoder,将文本编码后通过线性层投影到256维,对所有token做mean-pooling得到全局文本向量[B,1,256]。
跨模态融合: 50个可学习的Query Embedding加上全局文本向量后作为Transformer解码器的输入。Transformer编码器处理图像序列,解码器通过交叉注意力机制让每个Query关注图像中与文本相关的区域。
预测头: class_head输出[B,50,2]的分类logits,bbox_head输出[B,50,4]的归一化坐标,mask_proj将解码输出投影后与图像特征图做einsum点积生成[B,50,H/32,W/32]的Mask,最后通过双线性插值上采样回[B,50,512,512]。
训练配置: 加载 waiting_room_reasoning_dataset/annotations.json 真实数据集,冻结ResNet50和BERT骨干网络,AdamW优化器(学习率1e-4),训练80个epoch,batch_size为2,图像尺寸512。损失权重为loss_ce权重1.0、loss_mask权重5.0、loss_dice权重2.0。使用梯度裁剪(max_norm=0.1)防止Transformer训练初期的梯度爆炸。
输出: advanced_rfdetr_waiting_room.pth
运行命令: python train_real_data.py
完整代码:
import os
import json
import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision.models as models
from torch.utils.data import Dataset, DataLoader
import torchvision.transforms as T
from PIL import Image
import torch.optim as optim
import warnings
from transformers import AutoTokenizer, AutoModel
from huggingface_hub import snapshot_download
warnings.filterwarnings("ignore")
# 强制使用国内镜像站
os.environ["HF_ENDPOINT"] = "https://hf-mirror.com"
# 注释掉离线模式,允许首次完整下载
# os.environ["TRANSFORMERS_OFFLINE"] = "1"
# 引入之前的 Loss 函数
from loss import HungarianMatcher, SetCriterion
# 1. 适配新数据集的 Dataset
class RealReasoningDataset(Dataset):
def __init__(self, json_path, img_size=512):
with open(json_path, 'r', encoding='utf-8') as f:
self.annotations = json.load(f)
self.img_size = img_size
self.img_transform = T.Compose([
T.Resize((img_size, img_size)),
T.ToTensor(),
T.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
])
self.mask_transform = T.Compose([
T.Resize((img_size, img_size), interpolation=T.InterpolationMode.NEAREST),
T.ToTensor()
])
def __len__(self):
return len(self.annotations)
def __getitem__(self, idx):
anno = self.annotations[idx]
image = Image.open(anno['image_path']).convert('RGB')
mask = Image.open(anno['mask_path']).convert('L')
text = anno['text']
image_tensor = self.img_transform(image)
mask_tensor = self.mask_transform(mask)
mask_tensor = (mask_tensor > 0.5).float() # 二值化
return image_tensor, mask_tensor, text
# 2. RF-DETR (加入 2D 空间位置编码)
class AdvancedRFDETR(nn.Module):
def __init__(self, num_queries=50, hidden_dim=256):
super().__init__()
self.hidden_dim = hidden_dim
self.num_queries = num_queries
# 本地化加载 BERT
local_bert_path = "./models/bert-base-chinese"
if not os.path.exists(local_bert_path) or not os.listdir(local_bert_path):
print("从国内镜像补全 BERT 模型 ...")
os.makedirs(local_bert_path, exist_ok=True)
snapshot_download(
repo_id="bert-base-chinese",
local_dir=local_bert_path,
local_dir_use_symlinks=False,
resume_download=True,
endpoint="https://hf-mirror.com"
)
print(" BERT 下载完成!")
print(f"从本地加载 BERT: {local_bert_path}")
self.tokenizer = AutoTokenizer.from_pretrained(local_bert_path)
self.text_encoder = AutoModel.from_pretrained(local_bert_path)
self.text_proj = nn.Linear(768, hidden_dim)
# 图像特征
resnet = models.resnet50(weights=models.ResNet50_Weights.DEFAULT)
self.backbone = nn.Sequential(*list(resnet.children())[:-2])
self.conv_proj = nn.Conv2d(2048, hidden_dim, 1)
# 二维相对位置编码
self.row_embed = nn.Parameter(torch.rand(50, hidden_dim // 2))
self.col_embed = nn.Parameter(torch.rand(50, hidden_dim // 2))
# Transformer
self.query_embed = nn.Embedding(num_queries, hidden_dim)
self.transformer = nn.Transformer(
d_model=hidden_dim, nhead=8, num_encoder_layers=4, num_decoder_layers=4, batch_first=True
)
# 预测头
self.class_head = nn.Linear(hidden_dim, 2)
self.bbox_head = nn.Sequential(nn.Linear(hidden_dim, hidden_dim), nn.ReLU(), nn.Linear(hidden_dim, 4))
self.mask_proj = nn.Sequential(nn.Linear(hidden_dim, hidden_dim), nn.ReLU(), nn.Linear(hidden_dim, hidden_dim))
def forward(self, images, reasoning_texts):
B = images.shape[0]
device = images.device
text_inputs = self.tokenizer(reasoning_texts, return_tensors="pt", padding=True, truncation=True,
max_length=64).to(device)
text_features = self.text_proj(self.text_encoder(**text_inputs).last_hidden_state)
text_global = text_features.mean(dim=1).unsqueeze(1)
img_features = self.backbone(images)
H, W = img_features.shape[-2:]
img_features = self.conv_proj(img_features)
pos_x = self.col_embed[:W].unsqueeze(0).repeat(H, 1, 1)
pos_y = self.row_embed[:H].unsqueeze(1).repeat(1, W, 1)
pos_encoding = torch.cat([pos_y, pos_x], dim=-1).permute(2, 0, 1).unsqueeze(0).repeat(B, 1, 1, 1)
img_seq = (img_features + pos_encoding).flatten(2).permute(0, 2, 1)
queries = self.query_embed.weight.unsqueeze(0).repeat(B, 1, 1) + text_global
hs = self.transformer(src=img_seq, tgt=queries)
outputs_class = self.class_head(hs)
outputs_coord = self.bbox_head(hs).sigmoid()
mask_embeddings = self.mask_proj(hs)
outputs_mask = torch.einsum("bnc,bchw->bnhw", mask_embeddings, img_features)
outputs_mask = F.interpolate(outputs_mask, size=images.shape[-2:], mode="bilinear", align_corners=False)
return {"pred_logits": outputs_class, "pred_boxes": outputs_coord, "pred_masks": outputs_mask}
def freeze_backbones(self):
for param in self.backbone.parameters(): param.requires_grad = False
for param in self.text_encoder.parameters(): param.requires_grad = False
# 3. 真实数据训练闭环
def train_real_data():
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"开始在真实数据集上训练!使用设备: {device}")
json_path = "waiting_room_reasoning_dataset/annotations.json"
dataset = RealReasoningDataset(json_path=json_path, img_size=512)
dataloader = DataLoader(dataset, batch_size=2, shuffle=True)
print(f"数据加载完成,共有 {len(dataset)} 个真实样本。")
model = AdvancedRFDETR(num_queries=50).to(device)
model.freeze_backbones()
model.train()
matcher = HungarianMatcher()
criterion = SetCriterion(matcher).to(device)
optimizer = optim.AdamW(filter(lambda p: p.requires_grad, model.parameters()), lr=1e-4)
epochs = 80
for epoch in range(epochs):
epoch_loss = 0.0
for step, (images, masks, texts) in enumerate(dataloader):
images, masks = images.to(device), masks.to(device)
targets = [{"labels": torch.tensor([1], dtype=torch.int64, device=device), "masks": masks[i]} for i in
range(images.shape[0])]
outputs = model(images, texts)
loss_dict = criterion(outputs, targets)
loss = loss_dict['loss_ce'] * 1.0 + loss_dict['loss_mask'] * 5.0 + loss_dict['loss_dice'] * 2.0
optimizer.zero_grad()
loss.backward()
torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=0.1)
optimizer.step()
epoch_loss += loss.item()
print(f"Epoch [{epoch + 1}/{epochs}] 平均 Loss: {epoch_loss / len(dataloader):.4f}")
torch.save(model.state_dict(), "advanced_rfdetr_waiting_room.pth")
print("真实场景模型训练完成 权重已保存。")
if __name__ == "__main__":
train_real_data()
5.10 inference_real_data.py —— 真实数据推理脚本
角色: 最终验收推理脚本。
功能: 加载训练好的 advanced_rfdetr_waiting_room.pth 权重,取真实数据集的第13张图(可修改索引)进行前向推理。对50个Query的预测logits做softmax得到前景概率,取置信度最高的Query对应的Mask,用sigmoid转换后以0.5为阈值二值化。将原图逆归一化后与预测Mask叠加可视化,使用jet伪彩色映射让分割区域更加醒目。
输出: inference_real_result.png
运行命令: python inference_real_data.py
完整代码:
import os
import torch
import json
import matplotlib.pyplot as plt
import torchvision.transforms as T
from PIL import Image
import numpy as np
import warnings
warnings.filterwarnings("ignore")
# 直接从训练脚本中导入带有空间位置编码的终极版模型
from train_real_data import AdvancedRFDETR
def run_real_inference():
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"正在使用设备: {device}")
# 1. 加载我们自动生成的标注文件,挑一张图来测试 (比如第 0 张)
dataset_dir = "waiting_room_reasoning_dataset"
with open(os.path.join(dataset_dir, "annotations.json"), "r", encoding="utf-8") as f:
data = json.load(f)
# 修改这里的索引 (0 到 12),看看不同图片的预测效果
test_sample = data[13]
img_path = test_sample["image_path"]
# 既可以使用原来的中文指令,也可以试试 DINO 提取的英文特征
text_prompt = test_sample["text"]
print(f"\n[1] 测试图片: {img_path}")
print(f"[2] 推理指令: {text_prompt}")
# 2. 图像预处理
image = Image.open(img_path).convert("RGB")
transform = T.Compose([
T.Resize((512, 512)),
T.ToTensor(),
T.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
])
image_tensor = transform(image).unsqueeze(0).to(device)
# 3. 初始化模型并加载权重
print("[3] 正在加载候诊室专属模型权重...")
model = AdvancedRFDETR(num_queries=50).to(device)
model.load_state_dict(torch.load("advanced_rfdetr_waiting_room.pth", map_location=device))
model.eval() # 切换到推理模式
# 4. 执行前向推理
print("[4] 模型思考与分割中...")
with torch.no_grad():
outputs = model(image_tensor, [text_prompt])
# 5. 解析预测结果
probs = outputs["pred_logits"][0].softmax(dim=-1)
foreground_scores = probs[:, 1]
best_query_idx = foreground_scores.argmax().item()
best_score = foreground_scores[best_query_idx].item()
print(f"[5] 命中目标!最高置信度: {best_score:.4f} (来自 Query #{best_query_idx})")
# 提取对应的 Mask
best_mask = outputs["pred_masks"][0, best_query_idx].sigmoid()
binary_mask = (best_mask > 0.5).cpu().numpy()
# 6. 结果可视化
img_show = image_tensor[0].cpu().permute(1, 2, 0).numpy()
mean = np.array([0.485, 0.456, 0.406])
std = np.array([0.229, 0.224, 0.225])
img_show = std * img_show + mean
img_show = np.clip(img_show, 0, 1)
# 画图
plt.figure(figsize=(12, 6))
# 原图
plt.subplot(1, 2, 1)
plt.title("Original Waiting Room", fontsize=12)
plt.imshow(img_show)
plt.axis("off")
# 预测的 Mask 叠加
plt.subplot(1, 2, 2)
plt.title(f"Predicted Mask (Score: {best_score:.2f})", fontsize=12)
plt.imshow(img_show)
plt.imshow(binary_mask, cmap='jet', alpha=0.5) # 使用 jet 伪彩色,看起来更清晰
plt.axis("off")
plt.tight_layout()
plt.savefig("inference_real_result.png", bbox_inches='tight', dpi=150)
print("\n推理完成 可视化结果已保存为 inference_real_result.png")
if __name__ == "__main__":
run_real_inference()
六、执行顺序
整个项目的执行分为两个阶段,建议先跑通玩具数据验证链,再跑真实数据生产链。
阶段零:环境准备(一次性操作)
安装Python依赖包,包括torch、torchvision、transformers、opencv-python、pillow、openai、tqdm、matplotlib、huggingface_hub、scipy。安装Segment Anything官方库。手动下载sam_vit_h_4b8939.pth权重文件放入weights/目录。首次运行脚本时会自动从国内镜像下载grounding-dino-base和bert-base-chinese到models/目录。
阶段A:玩具数据验证链(快速验证流水线是否跑通)
第一步: 运行 python generate_toy_data.py,生成20张几何图形假图和对应的Mask,保存在data/目录。
第二步: dataset.py被其他脚本导入使用,无需单独执行。
第三步: 运行 python test_dataloader.py,验证DataLoader输出的维度是否正确,可视化图片和Mask,确认数据管道无误。
第四步: 运行 python train.py,在玩具数据上训练20个epoch,生成reasoning_rfdetr_best.pth权重文件。
第五步: 运行 python inference.py,加载权重对玩具数据进行推理,生成inference_result.png可视化结果。至此玩具数据闭环完成,证明数据管道、模型、损失函数、推理全链路跑通。
阶段B:真实数据生产链
第六步: 运行 python auto_labeling_pipeline.py。运行前需在脚本底部配置API_KEY和BASE_URL,将原始图片放入raw_waiting_room_photos/目录,在reasoning_tasks列表中定义推理指令。脚本遍历所有原始图片,对每张图执行MLLM推理、DINO定位、SAM分割三级流水线,生成waiting_room_reasoning_dataset/数据集。这一步最耗时,60张图乘以3条指令最多180次API调用。
第七步: 运行 python train_real_data.py,在真实数据上训练80个epoch,生成advanced_rfdetr_waiting_room.pth权重文件。
第八步: 运行 python inference_real_data.py,加载权重对真实数据进行推理,生成inference_real_result.png可视化结果。至此真实数据闭环完成。
脚本间的导入依赖关系
train.py导入dataset.py(ReasoningSegDataset类)、model.py(RFDETRSegmentation类)、loss.py(HungarianMatcher和SetCriterion类)test_dataloader.py导入dataset.pyinference.py导入model.pytrain_real_data.py导入loss.py,并在文件内部定义了RealReasoningDataset类和AdvancedRFDETR类inference_real_data.py从train_real_data.py导入 AdvancedRFDETR类auto_labeling_pipeline.py和generate_toy_data.py独立运行,无项目内部依赖
七、环境配置
7.1 硬件要求
| 项目 | 最低要求 | 推荐配置 |
|---|---|---|
| GPU | NVIDIA GPU, 8GB显存 | 12GB+显存 |
| CUDA | 11.8+ | 12.0+ |
| 内存 | 16GB | 32GB |
| 硬盘 | 10GB(含模型权重) | 20GB |
7.2 软件依赖安装
第一步,安装PyTorch(CUDA 11.8版本):
pip install torch torchvision torchaudio --index-url https://download.pytorch.org/whl/cu118
第二步,安装基础依赖(使用清华镜像加速):
pip install transformers opencv-python pillow openai tqdm matplotlib huggingface_hub scipy -i https://pypi.tuna.tsinghua.edu.cn/simple
第三步,安装Segment Anything官方库:
pip install git+https://github.com/facebookresearch/segment-anything.git
7.3 模型权重下载
| 权重 | 大小 | 下载方式 |
|---|---|---|
sam_vit_h_4b8939.pth |
~2.4GB | 手动下载 → weights/ |
grounding-dino-base |
~700MB | 首次运行自动下载(国内镜像) |
bert-base-chinese |
~400MB | 首次运行自动下载(国内镜像) |
7.4 API配置
在 auto_labeling_pipeline.py 脚本底部修改 API_KEY 和 BASE_URL 变量。API_KEY填写你的大模型API密钥,BASE_URL填写兼容OpenAI格式的API地址。
八、快速开始
最简流程(跑通真实数据只需四步)
第一步: 将你的原始图片放入 raw_waiting_room_photos/ 目录。
第二步: 编辑 auto_labeling_pipeline.py 底部的 API_KEY 和 BASE_URL,然后运行:
python auto_labeling_pipeline.py
第三步: 运行训练:
python train_real_data.py
第四步: 运行推理查看效果:
python inference_real_data.py
调试流程(用玩具数据快速验证)
依次运行:
python generate_toy_data.py
python test_dataloader.py
python train.py
python inference.py
九、核心算法解析
9.1 匈牙利匹配
DETR架构的核心挑战是模型输出50个Query预测,但每张图通常只有1个真实目标,需要确定哪个预测对应真实目标来计算损失。
解决方法是构建代价矩阵并使用匈牙利算法找最优匹配。代价矩阵由三部分组成:类别代价(预测类别概率的负对数)、BCE代价(预测Mask与真实Mask的二值交叉熵)、Dice代价(1减去Dice系数)。三个代价加权求和得到总代价矩阵,然后使用scipy的linear_sum_assignment算法找出代价最小的Query作为正样本匹配。匹配成功的Query计算分类损失和分割损失,其余49个Query标记为背景并计算背景惩罚损失(背景权重降低到0.1)。
9.2 Grounding DINO与SAM的分工
Grounding DINO和SAM虽然都涉及图像中的目标定位,但分工完全不同,是串联互补关系而非冗余。
| 维度 | Grounding DINO | SAM |
|---|---|---|
| 输出 | 矩形Bounding Box | 像素级Mask |
| 精度 | 目标级(粗粒度) | 像素级(细粒度) |
| 输入 | 图像 + 文本 | 图像 + Box/Point提示 |
| 角色 | “在哪里”(语义定位) | “精确形状”(像素分割) |
| 能否单独使用 | 能,但只有粗糙矩形框 | 不能,需要提示告诉它切哪里 |
为什么两个都要用: 只用DINO只能得到粗糙的矩形框,无法做像素级分割。只用SAM则不知道要分割哪个物体,需要DINO提供Box提示。二者串联后,DINO负责开放世界理解和粗定位,SAM负责闭集精细分割,形成完整的"理解→定位→分割"流水线。
9.3 2D空间位置编码
Transformer本身不具备空间位置感知能力,需要显式注入位置信息。AdvancedRFDETR使用可学习的2D位置编码:row_embed编码行位置(50个位置各128维),col_embed编码列位置(50个位置各128维)。对于特征图上的每个位置,将其行编码和列编码拼接得到256维的位置向量,加到图像特征上。这样Transformer就能区分"左上角"和"右下角"的特征,具备空间方向感。
9.4 损失函数组合
训练使用三种损失的加权组合:
| 损失 | 权重 | 作用 |
|---|---|---|
loss_ce (交叉熵) |
1.0 | 分类:前景 vs 背景 |
loss_mask (BCE) |
5.0 | 像素级二分类 |
loss_dice (Dice) |
2.0 | 优化Mask边缘重叠率 |
十、效果展示
| 文件 | 说明 |
|---|---|
dataloader_test_result.png |
DataLoader输出的图像和Mask并排可视化,用于验证数据管道正确性 |
inference_result.png |
玩具数据推理结果,展示模型对几何图形的分割效果 |
inference_real_result.png |
真实候诊室场景推理结果,展示模型在实际场景中的分割效果 |
十一、已知问题与改进方向
11.1 当前主要问题
问题一(最高优先级):训练数据量不足。 当前数据集约180条标注数据,而DETR类模型通常需要数千到数万样本才能充分收敛。数据量少导致模型容易过拟合训练集,泛化能力差。
问题二(最高优先级):Focal Loss已实现但未启用。 sigmoid_focal_loss函数已在loss.py中完整实现,但SetCriterion中仍使用普通的交叉熵损失。50个Query中只有1个是前景,类别极度不平衡,Focal Loss能有效缓解此问题。
问题三(高优先级):文本注入方式过于简单。 当前将BERT的全局向量通过mean-pooling后直接加到所有50个Query上,所有Query获得完全相同的文本信息,丧失了Query多样性。更好的做法是使用Cross-Attention让每个Query自主关注文本的不同部分。
问题四(高优先级):多目标场景强制单标签。 当前代码假设每张图只有1个目标(labels固定为[1]),但实际数据中一张图可能包含多个物体(沙发、茶几、椅子、绿植等)。合并后的Mask包含多个不连通区域,匈牙利匹配只分配1个Query来拟合这个复杂组合,非常困难。
问题五(中优先级):位置编码使用随机初始化。 row_embed和col_embed使用torch.rand均匀随机初始化,对于位置编码来说远不如正弦编码或学习的正弦初始化稳定,加上数据量少,这些参数难以收敛到有意义的值。
问题六(中优先级):Mask分辨率低。 图像特征图来自ResNet50最后一层,stride为32,512输入对应16x16的特征图。虽然通过双线性插值上采样回原尺寸,但细节已经丢失,小物体在特征图上可能只占1到2个像素。
问题七(中优先级):标签链路误差累积。 训练标签经过GPT-4o到Grounding DINO到SAM三级传递,每一环都可能引入误差(MLLM幻觉、DINO漏检误检、SAM边界噪声),这些误差会固化进训练标签中。
11.2 改进方向
- 数据扩充: 增加原始图片数量,目标将训练样本从180条提升到500条以上。添加数据增强模块,包括图像翻转、色彩扰动、随机裁剪,以及文本同义改写。
- 模型改进: 将文本注入方式从简单的mean-pooling加法改为Cross-Attention机制。引入FPN多尺度特征融合,让Mask Head能利用高分辨率特征。将位置编码从随机初始化改为正弦编码初始化。
- 损失优化: 在SetCriterion中启用已实现的
sigmoid_focal_loss替换普通交叉熵损失。修改targets构建逻辑,支持一张图多个独立Mask的多目标匹配。 - 训练监控: 划分训练集和验证集,每个epoch计算验证集上的mIoU指标。添加TensorBoard日志记录,跟踪损失曲线和指标变化,防止过拟合。
十二、技术栈
| 类别 | 技术 |
|---|---|
| 深度学习框架 | PyTorch, torchvision |
| 预训练模型 | ResNet50, BERT (bert-base-chinese) |
| 开放世界检测 | Grounding DINO (transformers) |
| 图像分割 | Segment Anything (SAM, vit_h) |
| 多模态大模型 | GPT-4o (OpenAI API) |
| 优化算法 | AdamW, 匈牙利算法 (scipy) |
| 数据处理 | OpenCV, PIL, NumPy |
| 可视化 | Matplotlib |
License
本项目仅供学习和研究使用。
致谢
更多推荐



所有评论(0)