SAM2 (Segment Anything 2): Model Adaptation and Training on Ascend NPUs
Check the Environment
1. After logging in to the machine over SSH, check the NPU device status. The following commands return the NPU device information.
npu-smi info # run on each instance node to see the NPU card status
npu-smi info -l | grep Total # run on each instance node to see the total card count and confirm that all expected cards are mounted
npu-smi info -t board -i 1 | egrep -i "software|firmware" # check the driver and firmware versions
Make sure the NPU devices are installed correctly.
2. Install the CANN packages
Before installing, make sure a Python environment and pip3 are available. CANN currently supports Python 3.7.x through 3.11.4; if your environment does not meet this requirement, run the following command to install them.
sudo apt-get install -y python3 python3-pip
Install the Toolkit development kit, the Kernels operator package, and the NNAL neural network acceleration library locally. They can be downloaded from the official site: https://www.hiascend.com/developer/download/community/result?module=cann&cann=8.2.RC1
3. Create and activate a Conda environment
conda create -n sam2 python=3.10.0
conda activate sam2
Prepare the SAM2 Code and Weights
1. Installation
git clone https://github.com/facebookresearch/sam2.git
cd sam2
pip install -e .
2. Download the checkpoints
cd checkpoints && \
./download_ckpts.sh && \
cd ..
3. Copy the configuration files to the project root
cp -r sam2/configs/ ./
4. Install the PyTorch framework and the torch_npu plugin
Install them directly from the wheel-format binary packages. Reference: https://www.hiascend.com/document/detail/zh/Pytorch/710/configandinstg/instg/insg_0004.html
5. Install the remaining dependencies:
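After installation, you can quickly verify that PyTorch can see the NPU (a minimal sanity-check sketch; importing torch_npu is what registers the torch.npu backend):

import torch
import torch_npu  # importing the plugin registers the torch.npu backend

print(torch.__version__)
print(torch.npu.is_available())  # True when an Ascend NPU is visible
x = torch.ones(2, 3).npu()       # move a tensor to the NPU
print((x + x).cpu())             # compute on the NPU and copy the result back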
pip install numpy==1.26.4
pip install matplotlib==3.10.7
pip install decorator==5.2.1
pip install scipy==1.15.3
pip install attrs==25.4.0
pip install psutil==7.1.0
pip install opencv-python==4.12.0.88
Inference on the NPU
Reference code:
import os
import time
import torch
import numpy as np
from PIL import Image
import matplotlib.pyplot as plt
from sam2.build_sam import build_sam2
from sam2.sam2_image_predictor import SAM2ImagePredictor
from hydra import initialize, compose
from hydra.core.global_hydra import GlobalHydra

try:
    import torch_npu  # the Ascend plugin registers the torch.npu backend on import
except ImportError:
    pass


def save_mask_with_alpha(image, mask, output_path):
    # Make sure the mask is uint8 with values 0 or 255 (used as the alpha channel)
    mask = (mask * 255).astype(np.uint8)
    # Create a new image object in RGBA mode
    rgba_image = Image.fromarray(image).convert('RGBA')
    # Create an alpha-channel image of the same size ('L' = grayscale)
    alpha = Image.fromarray(mask).convert('L')
    # Attach the alpha channel to rgba_image
    rgba_image.putalpha(alpha)
    # Save the result
    rgba_image.save(output_path, format='PNG')


# Clear any previous Hydra initialization
GlobalHydra.instance().clear()

# Set the working directory
os.chdir(os.path.dirname(os.path.abspath(__file__)) or '.')
print(f"Current working directory: {os.getcwd()}")

with initialize(config_path="configs/sam2.1", version_base=None):
    print("Hydra initialized")

    # Device selection (CUDA first, then NPU, then CPU)
    if torch.cuda.is_available():
        device = "cuda"
    elif hasattr(torch, "npu") and torch.npu.is_available():
        device = "npu"
    else:
        device = "cpu"
    print(f"using device: {device}")

    # Image path
    image_path = '/mnt/liboran/models/sam2/notebooks/images/truck.jpg'
    if not os.path.exists(image_path):
        raise FileNotFoundError(f"Image not found: {image_path}")
    image = np.array(Image.open(image_path).convert("RGB"))

    # Checkpoint path
    sam2_checkpoint = "./checkpoints/sam2.1_hiera_large.pt"
    if not os.path.exists(sam2_checkpoint):
        raise FileNotFoundError(f"Checkpoint not found: {sam2_checkpoint}")

    sam2_model = build_sam2("sam2.1_hiera_l", sam2_checkpoint, device=device)
    predictor = SAM2ImagePredictor(sam2_model)
    predictor.mask_threshold = 0.5

    image = np.ascontiguousarray(image, dtype=np.uint8)
    # Set the image
    predictor.set_image(image)
    print(f"Image embedding shape: {predictor._features['image_embed'].shape}")

    # Input point
    input_point = np.array([[500, 375]])
    input_label = np.array([1])

    # Inference
    start_time = time.time()
    masks, scores, logits = predictor.predict(
        point_coords=input_point,
        point_labels=input_label,
        multimask_output=True,
    )
    end_time = time.time()
    print(f"infer time is {end_time - start_time:.4f} seconds")

    # Sort the masks by score, best first
    sorted_ind = np.argsort(scores)[::-1]
    masks = masks[sorted_ind]
    scores = scores[sorted_ind]
    logits = logits[sorted_ind]

    def show_masks(image, masks, scores, point_coords=None, input_labels=None):
        for i, (mask, score) in enumerate(zip(masks, scores)):
            output_filename = f"Mask_{i+1}_Score_{score:.3f}.png"
            save_mask_with_alpha(image, mask, output_filename)
            print(f"Saved {output_filename}")

    # Save the results
    show_masks(image, masks, scores, point_coords=input_point, input_labels=input_label)
    print("Inference finished; results saved.")
Model Training
Download the training dataset
https://zenodo.org/records/3697452/files/LabPicsV1.zip?download=1
The downloaded LabPics1 dataset is used to segment materials and liquids.
SAM works by taking an image and a point in that image, then predicting the segmentation mask that contains the point. When reading the data, we therefore need to load the image, all segmentation masks in it, and one random point inside each mask.
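One way to fetch and unpack the archive directly from Python (a minimal sketch; the download is large, and the extraction target is a placeholder that should line up with the data_dir used in the training script below):

import os
import urllib.request
import zipfile

url = "https://zenodo.org/records/3697452/files/LabPicsV1.zip?download=1"
archive = "LabPicsV1.zip"
if not os.path.exists(archive):
    urllib.request.urlretrieve(url, archive)  # large download; may take a while
with zipfile.ZipFile(archive) as zf:
    zf.extractall("data/")  # adjust so the extracted LabPicsV1/ folder matches data_dir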
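The per-mask point sampling can be shown in isolation (a minimal sketch; mask is assumed to be a binary NumPy array, and the same logic appears inside read_batch in the training script):

import numpy as np

def sample_point_in_mask(mask):
    """Return one random foreground pixel of a binary mask as [x, y]."""
    coords = np.argwhere(mask > 0)                 # (row, col) pairs of foreground pixels
    y, x = coords[np.random.randint(len(coords))]  # argwhere gives (y, x); SAM expects (x, y)
    return [int(x), int(y)]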
Training code
import os
import torch
from contextlib import nullcontext
import numpy as np
import cv2
from sam2.build_sam import build_sam2
from sam2.sam2_image_predictor import SAM2ImagePredictor
from hydra import initialize, compose
from hydra.core.global_hydra import GlobalHydra

try:
    import torch_npu  # the Ascend plugin registers the torch.npu backend on import
except ImportError:
    pass

# Clear any previous Hydra initialization
if GlobalHydra.instance().is_initialized():
    GlobalHydra.instance().clear()

# Set the working directory
os.chdir(os.path.dirname(os.path.abspath(__file__)) or '.')
print(f"Current working directory: {os.getcwd()}")

# === Dataset path and data list ===
data_dir = r"/mnt/liboran/models/sam2/data/LabPicsV1/"
data = []
for name in os.listdir(data_dir + "Simple/Train/Image/"):
    image_path = data_dir + "Simple/Train/Image/" + name
    ann_path = data_dir + "Simple/Train/Instance/" + name[:-4] + ".png"
    if os.path.exists(image_path) and os.path.exists(ann_path):
        data.append({"image": image_path, "annotation": ann_path})
if len(data) == 0:
    raise ValueError("No valid image/annotation pairs found in the specified directory.")


# === read_batch (key point: returns a single mask plus one point inside it) ===
def read_batch(data):
    """Read a random image and its annotation from the dataset."""
    max_attempts = 100
    for _ in range(max_attempts):
        ent = data[np.random.randint(len(data))]
        img = cv2.imread(ent["image"])
        if img is None:
            continue
        img = img[..., ::-1]  # BGR -> RGB
        ann_map = cv2.imread(ent["annotation"], cv2.IMREAD_UNCHANGED)
        if ann_map is None:
            continue
        # Resize the image
        h, w = img.shape[:2]
        r = min(1024 / w, 1024 / h)
        new_w, new_h = int(w * r), int(h * r)
        img = cv2.resize(img, (new_w, new_h))
        ann_map = cv2.resize(ann_map, (new_w, new_h), interpolation=cv2.INTER_NEAREST)
        # Merge the annotation maps (materials + vessels)
        mat_map = ann_map[:, :, 0].copy()
        ves_map = ann_map[:, :, 2]
        max_mat_id = mat_map.max()
        mat_map[mat_map == 0] = ves_map[mat_map == 0] * (max_mat_id + 1)
        # Collect all instance ids
        inds = np.unique(mat_map)[1:]  # skip background 0
        if len(inds) == 0:
            continue
        ind = np.random.choice(inds)  # pick one random instance per step (avoids bias)
        mask = (mat_map == ind).astype(np.uint8)
        coords = np.argwhere(mask > 0)
        yx = coords[np.random.randint(len(coords))]
        point = [int(yx[1]), int(yx[0])]  # [x, y]
        return img, mask, point, 1  # return a single mask and point
    print("Warning: Failed to load valid batch after several attempts.")
    return None, np.empty((0, 0)), [], 0


# === Training procedure ===
with initialize(config_path="configs/sam2.1", version_base=None):
    print("Hydra initialized")

    # Model configuration
    sam2_checkpoint = "/mnt/liboran/models/sam2/checkpoints/sam2.1_hiera_small.pt"
    model_cfg_name = "sam2.1_hiera_s"

    # Device selection
    if torch.cuda.is_available():
        device = "cuda"
    elif hasattr(torch, "npu") and torch.npu.is_available():
        device = "npu"
    else:
        device = "cpu"
    print(f"Using device: {device}")

    # Build the model
    sam2_model = build_sam2(model_cfg_name, sam2_checkpoint, device=device)
    predictor = SAM2ImagePredictor(sam2_model)

    # Enable training mode for the mask decoder and prompt encoder
    predictor.model.sam_mask_decoder.train(True)
    predictor.model.sam_prompt_encoder.train(True)

    # Freeze the image encoder
    for param in predictor.model.image_encoder.parameters():
        param.requires_grad = False

    # Optimizer (train only the prompt encoder and mask decoder)
    optimizer = torch.optim.AdamW(
        [
            {"params": predictor.model.sam_prompt_encoder.parameters()},
            {"params": predictor.model.sam_mask_decoder.parameters()},
        ],
        lr=5e-5,
        weight_decay=4e-5
    )

    if device == "npu":
        scaler = torch.npu.amp.GradScaler()
    elif device == "cuda":
        scaler = torch.cuda.amp.GradScaler()
    else:
        scaler = None

    # Create the output directory
    os.makedirs("/mnt/liboran/models/sam2/checkpoints/train_output", exist_ok=True)

    # === Mean IoU over a sliding window ===
    iou_window = []    # most recent IoU values
    window_size = 100  # average over the last 100 samples

    print("Starting training loop...")
    for itr in range(100000):
        try:
            # Get a batch (single mask + single point)
            image, gt_mask_np, input_point, input_label = read_batch(data)
            if image is None or gt_mask_np.size == 0:
                continue

            # Convert to tensors
            input_point = torch.tensor([input_point], dtype=torch.float32, device=device).unsqueeze(0)  # [1, 1, 2]
            input_label = torch.tensor([[input_label]], dtype=torch.int32, device=device)  # [1, 1]

            # Mixed-precision context
            with torch.npu.amp.autocast() if device == "npu" else \
                    torch.cuda.amp.autocast() if device == "cuda" else nullcontext():
                predictor.set_image(image)

                # Prompt encoding
                sparse_embeddings, dense_embeddings = predictor.model.sam_prompt_encoder(
                    points=(input_point, input_label),
                    boxes=None,
                    masks=None,
                )

                # High-resolution features
                high_res_features = [
                    feat_level[-1].unsqueeze(0).to(device)
                    for feat_level in predictor._features["high_res_feats"]
                ]

                # Mask decoding
                low_res_masks, _, _, _ = predictor.model.sam_mask_decoder(
                    image_embeddings=predictor._features["image_embed"][-1].unsqueeze(0).to(device),
                    image_pe=predictor.model.sam_prompt_encoder.get_dense_pe().to(device),
                    sparse_prompt_embeddings=sparse_embeddings,
                    dense_prompt_embeddings=dense_embeddings,
                    multimask_output=True,
                    high_res_features=high_res_features,
                )

                # Upsample to the original resolution
                orig_hw = predictor._orig_hw[-1]
                prd_masks = predictor._transforms.postprocess_masks(low_res_masks, orig_hw)
                prd_mask_prob = torch.sigmoid(prd_masks[0, 0])  # [H, W]

                # Convert the ground truth to a tensor and align sizes
                gt_mask = torch.tensor(gt_mask_np, dtype=torch.float32, device=device)
                prd_mask_prob = torch.nn.functional.interpolate(
                    prd_mask_prob.unsqueeze(0).unsqueeze(0),
                    size=gt_mask.shape,
                    mode='bilinear',
                    align_corners=False
                ).squeeze()

                # Binarize the prediction
                pred_binary = (prd_mask_prob > 0.5).float()

                # Compute IoU (eps guards against division by zero)
                eps = 1e-6
                intersection = (pred_binary * gt_mask).sum()
                union = pred_binary.sum() + gt_mask.sum() - intersection
                iou = (intersection + eps) / (union + eps)
                iou_val = iou.item()

                # Sliding-window average
                iou_window.append(iou_val)
                if len(iou_window) > window_size:
                    iou_window.pop(0)
                mean_iou = np.mean(iou_window)

                # Loss function (binary cross-entropy)
                seg_loss = (-gt_mask * torch.log(prd_mask_prob + eps) -
                            (1 - gt_mask) * torch.log(1 - prd_mask_prob + eps)).mean()

            # Backpropagation
            optimizer.zero_grad()
            if scaler is not None:
                scaler.scale(seg_loss).backward()
                scaler.unscale_(optimizer)
                torch.nn.utils.clip_grad_norm_(predictor.model.parameters(), max_norm=1.0)
                scaler.step(optimizer)
                scaler.update()
            else:
                seg_loss.backward()
                torch.nn.utils.clip_grad_norm_(predictor.model.parameters(), max_norm=1.0)
                optimizer.step()

            # Logging
            if itr % 100 == 0:
                print(f"Step {itr}, Loss: {seg_loss.item():.4f}, Mean IOU: {mean_iou:.4f}")
            if itr % 1000 == 0:
                save_path = f"/mnt/liboran/models/sam2/checkpoints/train_output/sam2_finetuned_step_{itr}.pth"
                torch.save(predictor.model.state_dict(), save_path)
                print(f"Model saved to {save_path}")

        except Exception as e:
            print(f"Error at step {itr}: {str(e)}")
            continue

print("Training completed.")
Run the script
nohup python training-sam-new.py > training.log 2>&1 &
View the log (the training output is redirected to training.log):
tail -f training.log
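To reuse a fine-tuned checkpoint afterwards, the saved state dict can be loaded back into a freshly built model (a minimal sketch under the same Hydra setup as above; the step number in the filename is only an example):

import torch
import torch_npu  # required for device="npu"
from hydra import initialize
from hydra.core.global_hydra import GlobalHydra
from sam2.build_sam import build_sam2
from sam2.sam2_image_predictor import SAM2ImagePredictor

GlobalHydra.instance().clear()
with initialize(config_path="configs/sam2.1", version_base=None):
    model = build_sam2("sam2.1_hiera_s",
                       "/mnt/liboran/models/sam2/checkpoints/sam2.1_hiera_small.pt",
                       device="npu")
    state = torch.load("/mnt/liboran/models/sam2/checkpoints/train_output/sam2_finetuned_step_1000.pth",
                       map_location="cpu")
    model.load_state_dict(state)
    predictor = SAM2ImagePredictor(model)  # ready for point-prompted inference as shown earlier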