Pine Wilt Disease: Edge Model Training and Inference Deployment

HouYanSong, posted on 2025/11/15 17:10:12

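This article walks through the full workflow of training and deploying an edge model for pine wilt disease detection. First, the original 4032×3024 UAV images are preprocessed by resizing them to 1024×1024 to avoid running out of memory, and 9 classes are defined (sawtooth oak, netting, suspected, early, mild, moderate, severe, dead, and dead for over a year). The images are then sliced with a 20% overlap ratio to produce a SAHI dataset with 60000 training images and 6495 validation images. The model is trained from the yolo11s.yaml configuration on the pwd dataset for 10 epochs, although at least 100 epochs are recommended in practice. Evaluation shows that the model performs best on the pwd (severe) class (mAP50 of 0.707) and worst on the pwd_early (mild) class (mAP50 of 0.0362). To improve inference efficiency, the model is exported as a TensorRT FP16 engine, which can speed up GPU inference by up to 5x; inference takes about 20 ms per image. Finally, a Gradio application provides a user-friendly interface for real-time pine wilt disease detection, offering a practical solution for forest disease monitoring with good prospects for wider adoption.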

1. Raw Data

The original UAV images are 4032 x 3024; here they are resized to 1024 x 1024 to avoid running out of memory during model training:

%%writefile pwd.yaml
# Train/val/test sets as 1) dir: path/to/imgs, 2) file: path/to/imgs.txt, or 3) list: [path/to/imgs1, path/to/imgs2, ..]
path: /home/jetson/ultralytics/dataset/pwd  # dataset root dir (absolute path)
train: train/images  # train images (relative to 'path') 
val: val/images  # val images (relative to 'path') 
test:  # test images (optional)

# Classes
names:
  0: hardwood      # sawtooth oak (Quercus acutissima)
  1: net           # netting
  2: abnormal      # suspected
  3: pwd_pre_early # early stage
  4: pwd_early     # mild
  5: pwd_moderate  # moderate
  6: pwd           # severe
  7: dead_recent   # dead
  8: dead          # dead for over a year
Overwriting pwd.yaml

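The training set contains 3000 images and the validation set contains 529 images. The following code displays the annotations of a few validation images: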

import os
import cv2
import yaml
import random
import numpy as np
from matplotlib import pyplot as plt
%matplotlib inline

with open('pwd.yaml', 'r', encoding='utf-8') as f:
    data = yaml.safe_load(f)  # safe_load is sufficient for a plain config file
    
classes = data['names']
file_path = os.path.join(data['path'], 'val/images')
file_list = os.listdir(file_path)

img_paths = random.sample(file_list, 4)
img_lists = []

for img_path in img_paths:
    img_path = os.path.join(file_path, img_path)
    img = cv2.imread(img_path)
    h, w, _ = img.shape
    tl = round(0.002 * (h + w) / 2) + 1
    color = (0, 255, 255)
    # Derive the label path from the image path, whatever the image extension is
    label_path = os.path.splitext(img_path.replace("images", "labels"))[0] + ".txt"
    with open(label_path) as f:
        labels = f.readlines()
    for label in labels:
        l, x, y, wc, hc = [float(x) for x in label.strip().split()]
        x1 = int((x - wc / 2) * w)
        y1 = int((y - hc / 2) * h)
        x2 = int((x + wc / 2) * w)
        y2 = int((y + hc / 2) * h)

        cv2.rectangle(img, (x1, y1), (x2, y2),
                      color, thickness=tl, lineType=cv2.LINE_AA)
        cv2.putText(img,classes[int(l)],(x1, y1-10), cv2.FONT_HERSHEY_SIMPLEX, 1, color, 2, cv2.LINE_AA)
    img_lists.append(cv2.resize(img, (1024, 1024)))

image = np.concatenate([np.concatenate(img_lists[:2], axis=1), np.concatenate(img_lists[2:], axis=1)], axis=0)
cv2.imwrite("sample-pwd.png", image)

plt.rcParams["figure.figsize"] = (16, 16)
plt.imshow(image[:,:,::-1])
plt.axis('off')
plt.show()
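
The box conversion inside the loop above can be isolated into a small helper, which makes it easier to check by hand. The following is a minimal sketch assuming the standard YOLO label layout (class ID, then normalized center x, center y, width, height); the names `yolo_to_pixels` and `parse_label_line` are hypothetical and not part of the original notebook.

```python
def yolo_to_pixels(x_c, y_c, w_n, h_n, img_w, img_h):
    """Convert a normalized YOLO box (center x/y, width, height) to pixel corners."""
    x1 = int((x_c - w_n / 2) * img_w)
    y1 = int((y_c - h_n / 2) * img_h)
    x2 = int((x_c + w_n / 2) * img_w)
    y2 = int((y_c + h_n / 2) * img_h)
    return x1, y1, x2, y2


def parse_label_line(line):
    """Split one label line into an integer class ID and four normalized floats."""
    class_id, *box = line.split()
    return int(class_id), tuple(float(v) for v in box)


# Made-up label line: class 6 (pwd), box centered in the image
class_id, box = parse_label_line("6 0.5 0.5 0.25 0.5")
pixels = yolo_to_pixels(*box, img_w=1024, img_h=1024)
print(class_id, pixels)  # 6 (384, 256, 640, 768)
```

Because the coordinates are normalized, resizing an image does not require any change to its label file; only the `img_w` and `img_h` used for drawing change.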

2. Slicing the Data

The images in dataset/pwd are sliced into 1024 x 1024 tiles with a 20% overlap ratio, producing a new dataset, dataset/pwd-sahi.
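
As a rough check on the tile count, the sketch below enumerates 1024 x 1024 windows with 20% overlap over a 4032 x 3024 image, shifting the last row and column back inside the border. It assumes that slicing is applied to the full-resolution originals and that edge tiles are clamped rather than padded; `tile_boxes` is a hypothetical helper, not the tool that was used to build dataset/pwd-sahi.

```python
def tile_boxes(image_w, image_h, tile=1024, overlap_ratio=0.2):
    """Return [x1, y1, x2, y2] tiles covering the image, clamping edge tiles to the border."""
    overlap = int(tile * overlap_ratio)
    stride = tile - overlap
    boxes = []
    y = 0
    while True:
        y2 = min(y + tile, image_h)
        y1 = max(0, y2 - tile)  # shift the last row back inside the image
        x = 0
        while True:
            x2 = min(x + tile, image_w)
            x1 = max(0, x2 - tile)  # shift the last column back inside the image
            boxes.append([x1, y1, x2, y2])
            if x2 >= image_w:
                break
            x += stride
        if y2 >= image_h:
            break
        y += stride
    return boxes


boxes = tile_boxes(4032, 3024)
print(len(boxes))  # 20
```

Under these assumptions each image yields 5 x 4 = 20 tiles, so 3000 training images give 60000 tiles, which matches the training-set size of the sliced dataset. The sliced validation set is smaller than 529 x 20 = 10580 because it contains no background tiles.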

%%writefile pwd-sahi.yaml
# Train/val/test sets as 1) dir: path/to/imgs, 2) file: path/to/imgs.txt, or 3) list: [path/to/imgs1, path/to/imgs2, ..]
path: /home/jetson/ultralytics/dataset/pwd-sahi  # dataset root dir (absolute path)
train: train/images  # train images (relative to 'path') 
val: val/images  # val images (relative to 'path') 
test:  # test images (optional)

# Classes
names:
  0: hardwood      # sawtooth oak (Quercus acutissima)
  1: net           # netting
  2: abnormal      # suspected
  3: pwd_pre_early # early stage
  4: pwd_early     # mild
  5: pwd_moderate  # moderate
  6: pwd           # severe
  7: dead_recent   # dead
  8: dead          # dead for over a year
Overwriting pwd-sahi.yaml

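The sliced training set contains 60000 images (some of them background images) and the sliced validation set contains 6495 images (no background images). The following code displays the annotations of a few validation tiles: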

import os
import cv2
import yaml
import random
import numpy as np
from matplotlib import pyplot as plt
%matplotlib inline

with open('pwd-sahi.yaml', 'r', encoding='utf-8') as f:
    data = yaml.safe_load(f)  # safe_load is sufficient for a plain config file
    
classes = data['names']
file_path = os.path.join(data['path'], 'val/images')
file_list = os.listdir(file_path)

img_paths = random.sample(file_list, 4)
img_lists = []

for img_path in img_paths:
    img_path = os.path.join(file_path, img_path)
    img = cv2.imread(img_path)
    h, w, _ = img.shape
    tl = round(0.002 * (h + w) / 2) + 1
    color = (0, 255, 255)
    # Derive the label path from the image path, whatever the image extension is
    label_path = os.path.splitext(img_path.replace("images", "labels"))[0] + ".txt"
    with open(label_path) as f:
        labels = f.readlines()
    for label in labels:
        l, x, y, wc, hc = [float(x) for x in label.strip().split()]
        x1 = int((x - wc / 2) * w)
        y1 = int((y - hc / 2) * h)
        x2 = int((x + wc / 2) * w)
        y2 = int((y + hc / 2) * h)

        cv2.rectangle(img, (x1, y1), (x2, y2),
                      color, thickness=tl, lineType=cv2.LINE_AA)
        cv2.putText(img,classes[int(l)],(x1, y1-10), cv2.FONT_HERSHEY_SIMPLEX, 1, color, 2, cv2.LINE_AA)
    img_lists.append(cv2.resize(img, (1024, 1024)))

image = np.concatenate([np.concatenate(img_lists[:2], axis=1), np.concatenate(img_lists[2:], axis=1)], axis=0)
cv2.imwrite("sample-pwd-sahi.png", image)

plt.rcParams["figure.figsize"] = (16, 16)
plt.imshow(image[:,:,::-1])
plt.axis('off')
plt.show()
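
Looking at four random samples gives only a partial picture of the annotations. One way to see which classes are actually present is to count class IDs across all label files. The sketch below assumes YOLO-format .txt labels, with empty files treated as background; `count_classes` is a hypothetical helper, and the demo runs on made-up labels in a temporary directory rather than on the real dataset.

```python
import os
import tempfile
from collections import Counter


def count_classes(label_dir):
    """Count annotated instances per class ID across all YOLO .txt files in label_dir."""
    counts = Counter()
    for name in sorted(os.listdir(label_dir)):
        if not name.endswith(".txt"):
            continue
        with open(os.path.join(label_dir, name), encoding="utf-8") as f:
            for line in f:
                parts = line.split()
                if parts:  # blank lines and empty files contribute nothing
                    counts[int(parts[0])] += 1
    return counts


# Self-contained demo with made-up labels
with tempfile.TemporaryDirectory() as tmp:
    with open(os.path.join(tmp, "a.txt"), "w", encoding="utf-8") as f:
        f.write("6 0.5 0.5 0.2 0.2\n1 0.3 0.3 0.1 0.1\n")
    with open(os.path.join(tmp, "b.txt"), "w", encoding="utf-8") as f:
        f.write("6 0.4 0.6 0.2 0.3\n")
    open(os.path.join(tmp, "background.txt"), "w").close()
    demo_counts = count_classes(tmp)

print(sorted(demo_counts.items()))  # [(1, 1), (6, 2)]
```

On the real data, the same function could be pointed at a directory such as the val/labels folder under the dataset root, and the resulting IDs mapped through the `names` table of the YAML file.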

3. Model Training

We load the yolo11s.yaml model configuration and train on the dataset/pwd dataset for 10 epochs. The training results are saved in the pine_wilt_disease/yolo11s_10 directory:

%%writefile train.py
from ultralytics import YOLO

# Load a model
model = YOLO('yolo11s.yaml')  # load yaml model

# Train the model
results = model.train(data='pwd.yaml', epochs=10, imgsz=640,  workers=4, batch=8, project="pine_wilt_disease", name="yolo11s_10")
Overwriting train.py

Run in a terminal: /home/jetson/ultralytics/train.sh

To monitor training, run /home/jetson/ultralytics/tensorboard.sh in another terminal.

4. Model Evaluation

Load the trained model. Here we trained for only 10 epochs; in practice at least 100 epochs are needed to get good results:

from ultralytics import YOLO

# Load a model
model = YOLO('pine_wilt_disease/yolo11s_10/weights/best.pt')  # load the best model 

# Evaluate the model
metrics = model.val(
    data='pwd.yaml',      # dataset config
    imgsz=640,            # model input size
    workers=4,            # number of data-loading workers
    batch=8,              # validation batch size
    plots=True,           # save validation plots
    split='val'           # evaluate on the validation split
)
Ultralytics 8.3.55 🚀 Python-3.10.12 torch-2.5.0a0+872d972e41.nv24.08 CUDA:0 (Orin, 7620MiB)
YOLO11s summary (fused): 238 layers, 9,416,283 parameters, 0 gradients, 21.3 GFLOPs
val: Scanning /home/jetson/ultralytics/dataset/pwd/val/labels.cache... 529 images, 0 backgrounds, 0 corrupt: 100%|██████████| 529/529 [00:00<?, ?it/s]
                 Class     Images  Instances      Box(P          R      mAP50  mAP50-95): 100%|██████████| 67/67 [00:22<00:00,  2.92it/s]
                   all        529       8612      0.602      0.445      0.442      0.261
                   net        383       4210      0.681      0.625      0.683      0.401
             pwd_early        167        375          1          0     0.0362     0.0168
          pwd_moderate        253        503        0.4      0.225      0.224      0.108
                   pwd        383       1141      0.582      0.765      0.707      0.462
           dead_recent        376       1456      0.503      0.443      0.444       0.26
                  dead        229        927      0.444      0.613      0.557      0.318
Speed: 0.9ms preprocess, 24.5ms inference, 0.0ms loss, 4.1ms postprocess per image
Results saved to runs/detect/val

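Note that only 6 classes are actually annotated in the images: sawtooth oak (hardwood) and suspected (abnormal) are not included, and pwd_pre_early does not appear in the validation results either.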
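
The `all` row in the table is the unweighted mean over the classes that appear in the validation set, so a single weak class such as pwd_early pulls the overall score down noticeably. A short check, using the per-class values copied from the output above:

```python
# Per-class (mAP50, mAP50-95) copied from the validation table above
per_class = {
    "net": (0.683, 0.401),
    "pwd_early": (0.0362, 0.0168),
    "pwd_moderate": (0.224, 0.108),
    "pwd": (0.707, 0.462),
    "dead_recent": (0.444, 0.26),
    "dead": (0.557, 0.318),
}

# Unweighted mean over the six annotated classes
mean_map50 = sum(v[0] for v in per_class.values()) / len(per_class)
mean_map50_95 = sum(v[1] for v in per_class.values()) / len(per_class)

print(f"{mean_map50:.3f} {mean_map50_95:.3f}")  # 0.442 0.261
```

Both values agree with the `all` row (0.442 and 0.261), which confirms that the average is taken over six classes rather than the nine defined in pwd.yaml.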

5. Model Export

Exporting to TensorRT can speed up GPU inference by up to 5x:

from ultralytics import YOLO

model = YOLO("pine_wilt_disease/yolo11s_10/weights/best.pt")
# TensorRT FP16
model.export(format="engine", imgsz=640, batch=1, half=True)
WARNING ⚠️ TensorRT requires GPU export, automatically assigning device=0
Ultralytics 8.3.55 🚀 Python-3.10.12 torch-2.5.0a0+872d972e41.nv24.08 CUDA:0 (Orin, 7620MiB)
YOLO11s summary (fused): 238 layers, 9,416,283 parameters, 0 gradients, 21.3 GFLOPs

PyTorch: starting from 'pine_wilt_disease/yolo11s_10/weights/best.pt' with input shape (1, 3, 640, 640) BCHW and output shape(s) (1, 13, 8400) (18.3 MB)

ONNX: starting export with onnx 1.17.0 opset 19...
ONNX: slimming with onnxslim 0.1.47...
ONNX: export success ✅ 3.2s, saved as 'pine_wilt_disease/yolo11s_10/weights/best.onnx' (36.2 MB)

TensorRT: starting export with TensorRT 10.7.0...
[11/15/2025-16:23:19] [TRT] [I] [MemUsageChange] Init CUDA: CPU -2, GPU +0, now: CPU 1395, GPU 7158 (MiB)
[11/15/2025-16:23:25] [TRT] [I] [MemUsageChange] Init builder kernel library: CPU +970, GPU +258, now: CPU 2322, GPU 7418 (MiB)
[11/15/2025-16:23:26] [TRT] [I] ----------------------------------------------------------------
[11/15/2025-16:23:26] [TRT] [I] Input filename:   pine_wilt_disease/yolo11s_10/weights/best.onnx
[11/15/2025-16:23:26] [TRT] [I] ONNX IR version:  0.0.9
[11/15/2025-16:23:26] [TRT] [I] Opset version:    19
[11/15/2025-16:23:26] [TRT] [I] Producer name:    pytorch
[11/15/2025-16:23:26] [TRT] [I] Producer version: 2.5.0
[11/15/2025-16:23:26] [TRT] [I] Domain:           
[11/15/2025-16:23:26] [TRT] [I] Model version:    0
[11/15/2025-16:23:26] [TRT] [I] Doc string:       
[11/15/2025-16:23:26] [TRT] [I] ----------------------------------------------------------------
TensorRT: input "images" with shape(1, 3, 640, 640) DataType.FLOAT
TensorRT: output "output0" with shape(1, 13, 8400) DataType.FLOAT
TensorRT: building FP16 engine as pine_wilt_disease/yolo11s_10/weights/best.engine
[11/15/2025-16:23:26] [TRT] [I] Local timing cache in use. Profiling results in this builder pass will not be stored.
[11/15/2025-16:28:08] [TRT] [I] Compiler backend is used during engine build.
[11/15/2025-16:31:48] [TRT] [I] Detected 1 inputs and 1 output network tensors.
[11/15/2025-16:31:53] [TRT] [I] Total Host Persistent Memory: 543184 bytes
[11/15/2025-16:31:53] [TRT] [I] Total Device Persistent Memory: 0 bytes
[11/15/2025-16:31:53] [TRT] [I] Max Scratch Memory: 2764800 bytes
[11/15/2025-16:31:53] [TRT] [I] [BlockAssignment] Started assigning block shifts. This will take 162 steps to complete.
[11/15/2025-16:31:53] [TRT] [I] [BlockAssignment] Algorithm ShiftNTopDown took 19.653ms to assign 10 blocks to 162 nodes requiring 19046912 bytes.
[11/15/2025-16:31:53] [TRT] [I] Total Activation Memory: 19046400 bytes
[11/15/2025-16:31:53] [TRT] [I] Total Weights Memory: 18914082 bytes
[11/15/2025-16:31:53] [TRT] [I] Compiler backend is used during engine execution.
[11/15/2025-16:31:53] [TRT] [I] Engine generation completed in 506.948 seconds.
[11/15/2025-16:31:53] [TRT] [I] [MemUsageStats] Peak memory usage of TRT CPU/GPU memory allocators: CPU 2 MiB, GPU 140 MiB
TensorRT: export success ✅ 519.0s, saved as 'pine_wilt_disease/yolo11s_10/weights/best.engine' (21.6 MB)

Export complete (519.7s)
Results saved to /home/jetson/ultralytics/pine_wilt_disease/yolo11s_10/weights
Predict:         yolo predict task=detect model=pine_wilt_disease/yolo11s_10/weights/best.engine imgsz=640 half 
Validate:        yolo val task=detect model=pine_wilt_disease/yolo11s_10/weights/best.engine imgsz=640 data=pwd.yaml half 
Visualize:       https://netron.app

Exporting the FP16 engine takes about 10 minutes (519 s in the run above).
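
The actual speed-up from TensorRT depends on the device, batch size, and precision, so it is worth measuring on the target hardware. The following is a minimal sketch of a timing helper that works with any zero-argument callable; the name `benchmark` and its parameters are hypothetical, and a stand-in workload is used so the example runs without a GPU.

```python
import statistics
import time


def benchmark(fn, warmup=3, runs=20):
    """Time a zero-argument callable and return latency statistics in milliseconds."""
    for _ in range(warmup):  # warm-up calls are excluded from the statistics
        fn()
    samples = []
    for _ in range(runs):
        start = time.perf_counter()
        fn()
        samples.append((time.perf_counter() - start) * 1000.0)
    return {
        "median_ms": statistics.median(samples),
        "min_ms": min(samples),
        "max_ms": max(samples),
    }


# Stand-in workload so the sketch runs anywhere. On the device, pass for example
# lambda: model.predict(image, verbose=False) for the .pt and .engine models in turn.
stats = benchmark(lambda: sum(range(10_000)), warmup=2, runs=5)
print(stats)
```

This measures wall-clock time for the whole call, so with a model it includes preprocessing and postprocessing as well as the inference step reported in the Ultralytics logs.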

6. Model Inference

  1. Load the model with the TensorRT engine and run inference on some of the validation images; inference takes about 20 ms per image:
import cv2
import glob
from ultralytics import YOLO
import matplotlib.pyplot as plt
%matplotlib inline

# Load the TensorRT engine model
model = YOLO("pine_wilt_disease/yolo11s_10/weights/best.engine")
# Define the prediction function
def predict(image_path):
    results = model.predict(image_path, conf=0.45, iou=0.55)
    return results[0].plot()

# Load the images for inference
images_path = glob.glob("dataset/pwd/val/images/*.jpeg")
# Perform inference and display results
for image_path in images_path[:10]:
    result = predict(image_path)
    result = cv2.cvtColor(result, cv2.COLOR_BGR2RGB)
    result = cv2.resize(result, (4032 // 4, 3024 // 4))
    plt.imshow(result)
    plt.axis("off")
    plt.show()
WARNING ⚠️ Unable to automatically guess model task, assuming 'task=detect'. Explicitly define task for your model, i.e. 'task=detect', 'segment', 'classify','pose' or 'obb'.
Loading pine_wilt_disease/yolo11s_10/weights/best.engine for TensorRT inference...
[11/15/2025-16:31:54] [TRT] [I] Loaded engine size: 21 MiB
[11/15/2025-16:31:54] [TRT] [I] [MemUsageChange] TensorRT-managed allocation in IExecutionContext creation: CPU +0, GPU +18, now: CPU 0, GPU 36 (MiB)

image 1/1 /home/jetson/ultralytics/dataset/pwd/val/images/1d1d160a-ae4f-4fe4-801d-f001d4e7ff6d.jpeg: 640x640 4 nets, 1 pwd, 1 dead_recent, 20.1ms
Speed: 45.6ms preprocess, 20.1ms inference, 49.3ms postprocess per image at shape (1, 3, 640, 640)
...

  2. Build a Gradio application that runs pine wilt disease detection on uploaded images.
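
A minimal sketch of such an app is given below, assuming Gradio's `Interface` API and the engine path produced by the export step. The helper names `summarize` and `build_app`, the component labels, and the text summary are hypothetical; this is one possible layout, not necessarily how the author's application is written. The thresholds reuse the `conf=0.45` and `iou=0.55` values from the inference step.

```python
from collections import Counter


def summarize(class_ids, names):
    """Turn detected class IDs into a readable per-class count string."""
    counts = Counter(names[int(c)] for c in class_ids)
    if not counts:
        return "No detections"
    return ", ".join(f"{name}: {n}" for name, n in sorted(counts.items()))


def build_app(engine_path="pine_wilt_disease/yolo11s_10/weights/best.engine"):
    """Build a Gradio interface around the exported model (hypothetical layout)."""
    # Imported here so summarize() can be used without Gradio or a GPU available
    import gradio as gr
    from ultralytics import YOLO

    model = YOLO(engine_path, task="detect")

    def detect(image):
        # Gradio passes an RGB array; Ultralytics expects BGR for NumPy inputs
        result = model.predict(image[..., ::-1].copy(), conf=0.45, iou=0.55)[0]
        annotated = result.plot()[..., ::-1].copy()  # plot() returns BGR
        return annotated, summarize(result.boxes.cls.tolist(), result.names)

    return gr.Interface(
        fn=detect,
        inputs=gr.Image(type="numpy", label="UAV image"),
        outputs=[gr.Image(label="Detections"), gr.Textbox(label="Summary")],
        title="Pine Wilt Disease Detection",
    )


# The summary helper can be checked on its own with made-up detections
summary = summarize([6.0, 1.0, 6.0], {1: "net", 6: "pwd"})
print(summary)  # net: 1, pwd: 2
```

Calling `build_app().launch()` on the device would start the web server; passing `task="detect"` when loading the engine also avoids the task-guessing warning shown in the inference log.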

This concludes the chapter.
