SAM2 (Segment Anything 2): Adapting and Training the Model on Ascend NPUs

Published by HuaweiCloudDeveloper on 2026/01/07 10:36:17

Environment Check

1. After logging in to the machine over SSH, check the NPU device status. The following commands return NPU device information:

npu-smi info                    # run on each instance node to show NPU card status
npu-smi info -l | grep Total    # run on each instance node to show the total card count, confirming all cards are mounted
npu-smi info -t board -i 1 | egrep -i "software|firmware"   # show driver and firmware versions

Make sure the NPU devices are installed correctly before continuing.

 

2. Install CANN

Before installing, make sure a Python environment and pip3 are available. CANN currently supports Python 3.7.x through 3.11.4; if this requirement is not met, install Python with:

sudo apt-get install -y python3 python3-pip

Install the Toolkit development kit, the Kernels operator package, and the NNAL neural network acceleration library locally. They are best downloaded from the official site: https://www.hiascend.com/developer/download/community/result?module=cann&cann=8.2.RC1
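A minimal install sketch, assuming the three .run packages for your platform have been downloaded from the link above; the exact file names (architecture, chip variant) are assumptions, so adjust them to the packages you actually downloaded:

chmod +x Ascend-cann-*.run
./Ascend-cann-toolkit_8.2.RC1_linux-aarch64.run --install        # toolkit (file name assumed)
./Ascend-cann-kernels-910b_8.2.RC1_linux-aarch64.run --install   # kernels, chip-specific (file name assumed)
./Ascend-cann-nnal_8.2.RC1_linux-aarch64.run --install           # NNAL (file name assumed)
# Load the CANN environment variables (add to ~/.bashrc to persist)
source /usr/local/Ascend/ascend-toolkit/set_env.sh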

 

3. Create and activate a conda environment

conda create -n sam2 python=3.10.0

conda activate sam2

 

Preparing the SAM2 Code and Weights

1. Installation

git clone https://github.com/facebookresearch/sam2.git

cd sam2

pip install -e .
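To confirm the editable install succeeded, a quick import check:

python -c "import sam2; print('sam2 OK')"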

 

2. Download the checkpoints

cd checkpoints && \
./download_ckpts.sh && \
cd ..

 

3. Copy the config files to the project root. The scripts below initialize Hydra with config_path="configs/sam2.1", which is resolved relative to the working directory, so the configs must live under ./configs/:

cp -r sam2/configs/ ./

 

4. Install the PyTorch framework and the torch_npu plugin

Install both directly from wheel-format binary packages. Reference: https://www.hiascend.com/document/detail/zh/Pytorch/710/configandinstg/instg/insg_0004.html
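As a sketch only (the torch and torch_npu versions must match your CANN release per the document above; 2.1.0 here is an assumption), followed by a smoke test that the NPU is visible:

pip3 install torch==2.1.0      # version assumed; match it to your CANN release
pip3 install torch_npu         # pick the build matching your torch and CANN versions
python3 -c "import torch, torch_npu; print(torch.npu.is_available())"   # expect True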

 

5. Install the remaining dependencies:

pip install numpy==1.26.4
pip install matplotlib==3.10.7
pip install decorator==5.2.1
pip install scipy==1.15.3
pip install attrs==25.4.0
pip install psutil==7.1.0
pip install opencv-python==4.12.0.88


Inference on the NPU

Reference code:

import os
import time
import torch
import numpy as np
from PIL import Image
import matplotlib.pyplot as plt
from sam2.build_sam import build_sam2
from sam2.sam2_image_predictor import SAM2ImagePredictor
from hydra import initialize, compose
from hydra.core.global_hydra import GlobalHydra

try:
    import torch_npu  # registers the NPU backend; torch.npu exists only after this import
except ImportError:
    pass  # no NPU environment; the device check below falls back to CUDA/CPU

def save_mask_with_alpha(image, mask, output_path):
    # Ensure the mask is uint8 with values in 0-255 (used as the alpha channel)
    mask = (mask * 255).astype(np.uint8)
    # Create a new RGBA image from the input
    rgba_image = Image.fromarray(image).convert('RGBA')
    # Build an alpha-channel image of the same size ('L' means grayscale)
    alpha = Image.fromarray(mask).convert('L')
    # Attach the alpha channel to rgba_image
    rgba_image.putalpha(alpha)
    # Save the result
    rgba_image.save(output_path, format='PNG')

# Clear any previous Hydra initialization
GlobalHydra.instance().clear()

# Set the working directory
os.chdir(os.path.dirname(os.path.abspath(__file__)) or '.')
print(f"Current working directory: {os.getcwd()}")

with initialize(config_path="configs/sam2.1", version_base=None):
    print("Hydra initialized")

    # Device selection
    if torch.cuda.is_available():
        device = "cuda"
    elif hasattr(torch, "npu") and torch.npu.is_available():
        device = "npu"
    else:
        device = "cpu"

    print(f"using device: {device}")

    # Image path
    image_path = '/mnt/liboran/models/sam2/notebooks/images/truck.jpg'
    if not os.path.exists(image_path):
        raise FileNotFoundError(f"Image not found: {image_path}")
    image = np.array(Image.open(image_path).convert("RGB"))

    # Checkpoint path
    sam2_checkpoint = "./checkpoints/sam2.1_hiera_large.pt"
    if not os.path.exists(sam2_checkpoint):
        raise FileNotFoundError(f"Checkpoint not found: {sam2_checkpoint}")

    sam2_model = build_sam2("sam2.1_hiera_l", sam2_checkpoint, device=device)
    predictor = SAM2ImagePredictor(sam2_model)
    predictor.mask_threshold = 0.5

    image = np.ascontiguousarray(image, dtype=np.uint8)
    # Set the image
    predictor.set_image(image)
    print(f"Image embedding shape: {predictor._features['image_embed'].shape}")

    # Input point
    input_point = np.array([[500, 375]])
    input_label = np.array([1])

    # Inference
    start_time = time.time()
    masks, scores, logits = predictor.predict(
        point_coords=input_point,
        point_labels=input_label,
        multimask_output=True,
    )
    end_time = time.time()
    print(f"infer time is {end_time - start_time:.4f} seconds")

    # Sort masks by score, best first
    sorted_ind = np.argsort(scores)[::-1]
    masks = masks[sorted_ind]
    scores = scores[sorted_ind]
    logits = logits[sorted_ind]

    def show_masks(image, masks, scores, point_coords=None, input_labels=None):
        for i, (mask, score) in enumerate(zip(masks, scores)):
            output_filename = f"Mask_{i+1}_Score_{score:.3f}.png"
            save_mask_with_alpha(image, mask, output_filename)
            print(f"Saved {output_filename}")

    # Save the results
    show_masks(image, masks, scores, point_coords=input_point, input_labels=input_label)
    print("Inference finished; results saved.")

 

Model Training

Download the training dataset:

https://zenodo.org/records/3697452/files/LabPicsV1.zip?download=1
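A minimal fetch-and-unpack sketch, assuming wget and unzip are available; adjust the target so the extracted folder matches the data_dir used in the training script below:

wget "https://zenodo.org/records/3697452/files/LabPicsV1.zip?download=1" -O LabPicsV1.zip
unzip -q LabPicsV1.zip -d /mnt/liboran/models/sam2/data/   # should yield .../data/LabPicsV1/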

The downloaded LabPicsV1 dataset is used to segment materials and liquids.

SAM's basic mode of operation is to take an image plus a point in that image and predict the segmentation mask containing that point. When loading data, we therefore need to read the image, all of its segmentation masks, and one random point inside each mask.

 

Training code

import os
import torch
from contextlib import nullcontext
import numpy as np
import cv2

from sam2.build_sam import build_sam2
from sam2.sam2_image_predictor import SAM2ImagePredictor
from hydra import initialize, compose
from hydra.core.global_hydra import GlobalHydra

try:
    import torch_npu  # registers the NPU backend; torch.npu exists only after this import
except ImportError:
    pass  # no NPU environment; the device check below falls back to CUDA/CPU

# Clear any previous Hydra initialization
if GlobalHydra.instance().is_initialized():
    GlobalHydra.instance().clear()

# Set the working directory
os.chdir(os.path.dirname(os.path.abspath(__file__)) or '.')
print(f"Current working directory: {os.getcwd()}")

# === Dataset path and sample list ===
data_dir = r"/mnt/liboran/models/sam2/data/LabPicsV1/"
data = []
for name in os.listdir(data_dir + "Simple/Train/Image/"):
    image_path = data_dir + "Simple/Train/Image/" + name
    ann_path = data_dir + "Simple/Train/Instance/" + name[:-4] + ".png"
    if os.path.exists(image_path) and os.path.exists(ann_path):
        data.append({"image": image_path, "annotation": ann_path})

if len(data) == 0:
    raise ValueError("No valid image/annotation pairs found in the specified directory.")

# === read_batch (key point: returns a single mask plus one point inside it) ===
def read_batch(data):
    """Read a random image and its annotation from the dataset."""
    max_attempts = 100
    for _ in range(max_attempts):
        ent = data[np.random.randint(len(data))]
        img = cv2.imread(ent["image"])
        if img is None:  # check before touching the array
            continue
        img = img[..., ::-1].copy()  # BGR -> RGB (copy keeps strides positive for torch)
        ann_map = cv2.imread(ent["annotation"], cv2.IMREAD_UNCHANGED)
        if ann_map is None:
            continue

        # Resize image and annotation so the long side is at most 1024
        h, w = img.shape[:2]
        r = min(1024 / w, 1024 / h)
        new_w, new_h = int(w * r), int(h * r)
        img = cv2.resize(img, (new_w, new_h))
        ann_map = cv2.resize(ann_map, (new_w, new_h), interpolation=cv2.INTER_NEAREST)

        # Merge the annotation maps (materials + vessels)
        mat_map = ann_map[:, :, 0].copy()
        ves_map = ann_map[:, :, 2]
        max_mat_id = mat_map.max()
        mat_map[mat_map == 0] = ves_map[mat_map == 0] * (max_mat_id + 1)

        # Collect all instance ids, skipping background 0
        inds = np.unique(mat_map)[1:]
        if len(inds) == 0:
            continue

        ind = np.random.choice(inds)  # train on one random instance to avoid bias
        mask = (mat_map == ind).astype(np.uint8)
        coords = np.argwhere(mask > 0)
        yx = coords[np.random.randint(len(coords))]
        point = [int(yx[1]), int(yx[0])]  # [x, y]

        return img, mask, point, 1  # single mask and point

    print("Warning: failed to load a valid batch after several attempts.")
    return None, np.empty((0, 0)), [], 0


# === Training loop ===
with initialize(config_path="configs/sam2.1", version_base=None):
    print("Hydra initialized")

    # Model configuration
    sam2_checkpoint = "/mnt/liboran/models/sam2/checkpoints/sam2.1_hiera_small.pt"
    model_cfg_name = "sam2.1_hiera_s"

    # Select the device
    if torch.cuda.is_available():
        device = "cuda"
    elif hasattr(torch, "npu") and torch.npu.is_available():
        device = "npu"
    else:
        device = "cpu"

    print(f"Using device: {device}")

    # Build the model
    sam2_model = build_sam2(model_cfg_name, sam2_checkpoint, device=device)
    predictor = SAM2ImagePredictor(sam2_model)

    # Enable training mode for the mask decoder and prompt encoder
    predictor.model.sam_mask_decoder.train(True)
    predictor.model.sam_prompt_encoder.train(True)

    # Freeze the image encoder
    for param in predictor.model.image_encoder.parameters():
        param.requires_grad = False

    # Optimizer (only the prompt encoder and mask decoder are trained)
    optimizer = torch.optim.AdamW(
        [
            {"params": predictor.model.sam_prompt_encoder.parameters()},
            {"params": predictor.model.sam_mask_decoder.parameters()},
        ],
        lr=5e-5,
        weight_decay=4e-5
    )

    if device == "npu":
        scaler = torch.npu.amp.GradScaler()
    elif device == "cuda":
        scaler = torch.cuda.amp.GradScaler()
    else:
        scaler = None

    # Create the output directory
    os.makedirs("/mnt/liboran/models/sam2/checkpoints/train_output", exist_ok=True)

    # === Mean IoU tracking ===
    iou_window = []  # stores the most recent N IoU values
    window_size = 100  # average over the last 100 samples

    print("Starting training loop...")

    for itr in range(100000):
        try:
            # Fetch a batch (single mask + point)
            image, gt_mask_np, input_point, input_label = read_batch(data)
            if image is None or gt_mask_np.size == 0:
                continue

            # Convert to tensors
            input_point = torch.tensor([input_point], dtype=torch.float32, device=device).unsqueeze(0)  # [1, 1, 2]
            input_label = torch.tensor([[input_label]], dtype=torch.int32, device=device)  # [1, 1]

            # Mixed-precision context
            with torch.npu.amp.autocast() if device == "npu" else \
                 torch.cuda.amp.autocast() if device == "cuda" else nullcontext():

                predictor.set_image(image)

                # Prompt encoding
                sparse_embeddings, dense_embeddings = predictor.model.sam_prompt_encoder(
                    points=(input_point, input_label),
                    boxes=None,
                    masks=None,
                )

                # High-resolution features
                high_res_features = [
                    feat_level[-1].unsqueeze(0).to(device)
                    for feat_level in predictor._features["high_res_feats"]
                ]

                # Mask decoding
                low_res_masks, _, _, _ = predictor.model.sam_mask_decoder(
                    image_embeddings=predictor._features["image_embed"][-1].unsqueeze(0).to(device),
                    image_pe=predictor.model.sam_prompt_encoder.get_dense_pe().to(device),
                    sparse_prompt_embeddings=sparse_embeddings,
                    dense_prompt_embeddings=dense_embeddings,
                    multimask_output=True,
                    repeat_image=False,  # image batch already matches the prompt batch
                    high_res_features=high_res_features,
                )

                # Upsample to the original resolution
                orig_hw = predictor._orig_hw[-1]
                prd_masks = predictor._transforms.postprocess_masks(low_res_masks, orig_hw)
                prd_mask_prob = torch.sigmoid(prd_masks[0, 0])  # [H, W]

            # Convert GT to a tensor and align its size with the prediction
            gt_mask = torch.tensor(gt_mask_np, dtype=torch.float32, device=device)
            prd_mask_prob = torch.nn.functional.interpolate(
                prd_mask_prob.unsqueeze(0).unsqueeze(0),
                size=gt_mask.shape,
                mode='bilinear',
                align_corners=False
            ).squeeze()

            # Binarize the prediction
            pred_binary = (prd_mask_prob > 0.5).float()

            # IoU (eps guards against division by zero)
            eps = 1e-6
            intersection = (pred_binary * gt_mask).sum()
            union = pred_binary.sum() + gt_mask.sum() - intersection
            iou = (intersection + eps) / (union + eps)
            iou_val = iou.item()

            # Sliding-window average
            iou_window.append(iou_val)
            if len(iou_window) > window_size:
                iou_window.pop(0)
            mean_iou = np.mean(iou_window)

            # Loss (binary cross-entropy)
            seg_loss = (-gt_mask * torch.log(prd_mask_prob + eps) -
                       (1 - gt_mask) * torch.log(1 - prd_mask_prob + eps)).mean()

            # Backpropagation
            optimizer.zero_grad()

            if scaler is not None:
                scaler.scale(seg_loss).backward()
                scaler.unscale_(optimizer)
                torch.nn.utils.clip_grad_norm_(predictor.model.parameters(), max_norm=1.0)
                scaler.step(optimizer)
                scaler.update()
            else:
                seg_loss.backward()
                torch.nn.utils.clip_grad_norm_(predictor.model.parameters(), max_norm=1.0)
                optimizer.step()

            # Logging
            if itr % 100 == 0:
                print(f"Step {itr}, Loss: {seg_loss.item():.4f}, Mean IOU: {mean_iou:.4f}")

            if itr % 1000 == 0:
                save_path = f"/mnt/liboran/models/sam2/checkpoints/train_output/sam2_finetuned_step_{itr}.pth"
                torch.save(predictor.model.state_dict(), save_path)
                print(f"Model saved to {save_path}")

        except Exception as e:
            print(f"Error at step {itr}: {str(e)}")
            continue

print("Training completed.")

 

Run the script

nohup python training-sam-new.py > training.log 2>&1 &

 

Check the log
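For example, follow the training log as it is written:

tail -f training.log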
