Pine Wilt Disease: Edge Model Training and Inference Deployment
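This article walks through the full workflow of training and deploying an edge model for pine wilt disease detection. The raw drone images (4032×3024) are first scaled to 1024×1024 to avoid running out of memory during training, and 9 classes are defined: hardwood (sawtooth oak), net, abnormal (suspected), and six disease stages running from early, mild, moderate and severe to recently dead and dead for over a year. The images are then sliced with a 20% overlap into a SAHI dataset of 60000 training images and 6495 validation images. A model built from the yolo11s.yaml config is trained on the pwd dataset for 10 epochs, although at least 100 epochs are recommended in practice. In evaluation, the model does best on the pwd (severe) class, with an mAP50 of 0.707, and worst on the pwd_early (mild) class, with an mAP50 of 0.0362. To speed up inference, the model is exported to a TensorRT FP16 engine, a format that can speed up GPU inference by up to 5x; here a single image takes about 20 ms. Finally, a Gradio application wraps the detector so that users can upload an image and get pine wilt disease detections, which gives forest disease monitoring a practical tool.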
1. Raw data
The raw drone images are 4032 x 3024. They are scaled to 1024 x 1024 here to avoid running out of memory during training:
%%writefile pwd.yaml
# Train/val/test sets as 1) dir: path/to/imgs, 2) file: path/to/imgs.txt, or 3) list: [path/to/imgs1, path/to/imgs2, ..]
path: /home/jetson/ultralytics/dataset/pwd  # dataset root dir (absolute path)
train: train/images  # train images (relative to 'path')
val: val/images  # val images (relative to 'path')
test:  # test images (optional)
# Classes
names:
  0: hardwood  # sawtooth oak
  1: net  # tree covered with netting
  2: abnormal  # suspected
  3: pwd_pre_early  # early stage
  4: pwd_early  # mild
  5: pwd_moderate  # moderate
  6: pwd  # severe
  7: dead_recent  # recently dead
  8: dead  # dead for over a year
Overwriting pwd.yaml
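The training set contains 3000 images and the validation set 529 images. The following cell draws the annotations on four random validation images: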
import os
import random

import cv2
import numpy as np
import yaml
from matplotlib import pyplot as plt
%matplotlib inline

# Read the dataset root and class names from the dataset config
with open('pwd.yaml', 'r', encoding='utf-8') as f:
    data = yaml.safe_load(f)
classes = data['names']

file_path = os.path.join(data['path'], 'val/images')
file_list = os.listdir(file_path)
img_paths = random.sample(file_list, 4)  # four random validation images
img_lists = []
for img_name in img_paths:
    img_path = os.path.join(file_path, img_name)
    img = cv2.imread(img_path)
    h, w, _ = img.shape
    tl = round(0.002 * (h + w) / 2) + 1  # box line thickness, scaled to image size
    color = (0, 255, 255)
    # The label file mirrors the image path: images -> labels, extension -> .txt
    label_path = os.path.splitext(img_path.replace("images", "labels"))[0] + ".txt"
    with open(label_path) as f:
        labels = f.readlines()
    for label in labels:
        # YOLO format: class x_center y_center width height, all normalized to [0, 1]
        l, x, y, wc, hc = [float(v) for v in label.strip().split()]
        x1 = int((x - wc / 2) * w)
        y1 = int((y - hc / 2) * h)
        x2 = int((x + wc / 2) * w)
        y2 = int((y + hc / 2) * h)
        cv2.rectangle(img, (x1, y1), (x2, y2),
                      color, thickness=tl, lineType=cv2.LINE_AA)
        cv2.putText(img, classes[int(l)], (x1, y1 - 10),
                    cv2.FONT_HERSHEY_SIMPLEX, 1, color, 2, cv2.LINE_AA)
    img_lists.append(cv2.resize(img, (1024, 1024)))

# Arrange the four annotated images in a 2 x 2 grid
image = np.concatenate([np.concatenate(img_lists[:2], axis=1),
                        np.concatenate(img_lists[2:], axis=1)], axis=0)
cv2.imwrite("sample-pwd.png", image)
plt.rcParams["figure.figsize"] = (16, 16)
plt.imshow(image[:, :, ::-1])  # BGR -> RGB for matplotlib
plt.axis('off')
plt.show()
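
The box arithmetic in the cell above can be isolated into a small helper. This is a minimal sketch, not part of the original notebook: the name `yolo_to_pixel_box` and the clamping to the image bounds are assumptions added for illustration.

```python
def yolo_to_pixel_box(label_line, img_w, img_h):
    """Convert one YOLO label line to (class_id, x1, y1, x2, y2) in pixels.

    A YOLO label line is "class x_center y_center width height",
    with the four box values normalized to [0, 1].
    """
    cls, xc, yc, bw, bh = (float(v) for v in label_line.split())
    x1 = int((xc - bw / 2) * img_w)
    y1 = int((yc - bh / 2) * img_h)
    x2 = int((xc + bw / 2) * img_w)
    y2 = int((yc + bh / 2) * img_h)
    # Assumption: clamp to the image so boxes near the border stay drawable
    x1, x2 = max(0, x1), min(img_w, x2)
    y1, y2 = max(0, y1), min(img_h, y2)
    return int(cls), x1, y1, x2, y2


# Class 6 (pwd), centered box covering a quarter of the width and half the height
print(yolo_to_pixel_box("6 0.5 0.5 0.25 0.5", 1024, 1024))
# -> (6, 384, 256, 640, 768)
```

Because the labels are normalized, scaling an image from 4032 x 3024 to 1024 x 1024 does not require rewriting the label files.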

2. Slicing the data
Slice the images in dataset/pwd into 1024 x 1024 tiles with a 20% overlap to produce a new dataset, dataset/pwd-sahi:
%%writefile pwd-sahi.yaml
# Train/val/test sets as 1) dir: path/to/imgs, 2) file: path/to/imgs.txt, or 3) list: [path/to/imgs1, path/to/imgs2, ..]
path: /home/jetson/ultralytics/dataset/pwd-sahi  # dataset root dir (absolute path)
train: train/images  # train images (relative to 'path')
val: val/images  # val images (relative to 'path')
test:  # test images (optional)
# Classes
names:
  0: hardwood  # sawtooth oak
  1: net  # tree covered with netting
  2: abnormal  # suspected
  3: pwd_pre_early  # early stage
  4: pwd_early  # mild
  5: pwd_moderate  # moderate
  6: pwd  # severe
  7: dead_recent  # recently dead
  8: dead  # dead for over a year
Overwriting pwd-sahi.yaml
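The sliced training set contains 60000 images (some of them background-only tiles) and the validation set 6495 images (no background tiles). The following cell draws the annotations on four random validation tiles: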
import os
import random

import cv2
import numpy as np
import yaml
from matplotlib import pyplot as plt
%matplotlib inline

# Read the dataset root and class names from the sliced dataset config
with open('pwd-sahi.yaml', 'r', encoding='utf-8') as f:
    data = yaml.safe_load(f)
classes = data['names']

file_path = os.path.join(data['path'], 'val/images')
file_list = os.listdir(file_path)
img_paths = random.sample(file_list, 4)  # four random validation tiles
img_lists = []
for img_name in img_paths:
    img_path = os.path.join(file_path, img_name)
    img = cv2.imread(img_path)
    h, w, _ = img.shape
    tl = round(0.002 * (h + w) / 2) + 1  # box line thickness, scaled to image size
    color = (0, 255, 255)
    # The label file mirrors the image path: images -> labels, extension -> .txt
    label_path = os.path.splitext(img_path.replace("images", "labels"))[0] + ".txt"
    with open(label_path) as f:
        labels = f.readlines()
    for label in labels:
        # YOLO format: class x_center y_center width height, all normalized to [0, 1]
        l, x, y, wc, hc = [float(v) for v in label.strip().split()]
        x1 = int((x - wc / 2) * w)
        y1 = int((y - hc / 2) * h)
        x2 = int((x + wc / 2) * w)
        y2 = int((y + hc / 2) * h)
        cv2.rectangle(img, (x1, y1), (x2, y2),
                      color, thickness=tl, lineType=cv2.LINE_AA)
        cv2.putText(img, classes[int(l)], (x1, y1 - 10),
                    cv2.FONT_HERSHEY_SIMPLEX, 1, color, 2, cv2.LINE_AA)
    img_lists.append(cv2.resize(img, (1024, 1024)))

# Arrange the four annotated tiles in a 2 x 2 grid
image = np.concatenate([np.concatenate(img_lists[:2], axis=1),
                        np.concatenate(img_lists[2:], axis=1)], axis=0)
cv2.imwrite("sample-pwd-sahi.png", image)
plt.rcParams["figure.figsize"] = (16, 16)
plt.imshow(image[:, :, ::-1])  # BGR -> RGB for matplotlib
plt.axis('off')
plt.show()
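
The slicing itself is not shown in this section. As a rough illustration of how 1024 x 1024 tiles with a 20% overlap cover an image, the sketch below enumerates sliding-window boxes. It is a hand-written approximation, not SAHI's API: the function name `tile_boxes` and the rule of shifting the last tile back inside the image are assumptions.

```python
def tile_boxes(img_w, img_h, tile=1024, overlap=0.2):
    """Return (x1, y1, x2, y2) boxes of overlapping square tiles covering the image."""
    step = tile - int(tile * overlap)  # 1024 - 204 = 820 pixels between tile origins
    boxes = []
    y = 0
    while True:
        y2 = min(y + tile, img_h)
        y1 = max(y2 - tile, 0)  # shift the last row back so it stays inside the image
        x = 0
        while True:
            x2 = min(x + tile, img_w)
            x1 = max(x2 - tile, 0)  # same for the last column
            boxes.append((x1, y1, x2, y2))
            if x2 >= img_w:
                break
            x += step
        if y2 >= img_h:
            break
        y += step
    return boxes


tiles_per_image = len(tile_boxes(4032, 3024))
print(tiles_per_image)         # -> 20 (5 columns x 4 rows)
print(3000 * tiles_per_image)  # -> 60000
```

Under this scheme a 4032 x 3024 image yields 20 tiles, and 3000 training images yield 60000 tiles, which matches the training set size reported above. That agreement suggests the slicing was run on full-resolution images, but this is an inference from the numbers rather than something the text states. For the validation set, 529 x 20 = 10580 tiles, of which 6495 remain once background-only tiles are excluded.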

3. Model training
We build the model from the yolo11s.yaml config file and train it on the dataset/pwd dataset for 10 epochs. The training results are saved under pine_wilt_disease/yolo11s_10:
%%writefile train.py
from ultralytics import YOLO
# Load a model
model = YOLO('yolo11s.yaml') # load yaml model
# Train the model
results = model.train(data='pwd.yaml', epochs=10, imgsz=640, workers=4, batch=8, project="pine_wilt_disease", name="yolo11s_10")
Overwriting train.py
Run in a terminal: /home/jetson/ultralytics/train.sh

Running /home/jetson/ultralytics/tensorboard.sh in another terminal lets you monitor training progress.


4. Model evaluation
Load the trained model. It was trained for only 10 epochs here; in practice, at least 100 epochs are needed for good results:
from ultralytics import YOLO
# Load a model
model = YOLO('pine_wilt_disease/yolo11s_10/weights/best.pt')  # load the best model
# Evaluate the model
metrics = model.val(
    data='pwd.yaml',  # dataset config
    imgsz=640,        # model input size
    workers=4,        # data-loading workers
    batch=8,          # validation batch size
    plots=True,       # save validation plots
    split='val'       # evaluate on the validation split
)
Ultralytics 8.3.55 🚀 Python-3.10.12 torch-2.5.0a0+872d972e41.nv24.08 CUDA:0 (Orin, 7620MiB)
YOLO11s summary (fused): 238 layers, 9,416,283 parameters, 0 gradients, 21.3 GFLOPs
val: Scanning /home/jetson/ultralytics/dataset/pwd/val/labels.cache... 529 images, 0 backgrounds, 0 corrupt: 100%|██████████| 529/529 [00:00<?, ?it/s]
Class Images Instances Box(P R mAP50 mAP50-95): 100%|██████████| 67/67 [00:22<00:00, 2.92it/s]
all 529 8612 0.602 0.445 0.442 0.261
net 383 4210 0.681 0.625 0.683 0.401
pwd_early 167 375 1 0 0.0362 0.0168
pwd_moderate 253 503 0.4 0.225 0.224 0.108
pwd 383 1141 0.582 0.765 0.707 0.462
dead_recent 376 1456 0.503 0.443 0.444 0.26
dead 229 927 0.444 0.613 0.557 0.318
Speed: 0.9ms preprocess, 24.5ms inference, 0.0ms loss, 4.1ms postprocess per image
Results saved to runs/detect/val
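Note that only 6 of the 9 classes actually appear in the annotations: hardwood, abnormal and pwd_pre_early have no instances in the validation results above.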
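
To read the table above, it helps to check how the "all" row relates to the per-class rows. The sketch below copies the six per-class rows from the log and averages them without weighting; the variable names are illustrative only.

```python
# (precision, recall, mAP50, mAP50-95) per class, copied from the validation log
per_class = {
    "net":          (0.681, 0.625, 0.683,  0.401),
    "pwd_early":    (1.0,   0.0,   0.0362, 0.0168),
    "pwd_moderate": (0.4,   0.225, 0.224,  0.108),
    "pwd":          (0.582, 0.765, 0.707,  0.462),
    "dead_recent":  (0.503, 0.443, 0.444,  0.26),
    "dead":         (0.444, 0.613, 0.557,  0.318),
}


def mean_column(rows, index):
    """Unweighted mean of one metric column across classes."""
    return sum(row[index] for row in rows) / len(rows)


all_row = tuple(round(mean_column(list(per_class.values()), i), 3) for i in range(4))
print(all_row)  # -> (0.602, 0.445, 0.442, 0.261), the "all" row of the log

# Mean precision over the five classes other than pwd_early
others = [row for name, row in per_class.items() if name != "pwd_early"]
print(round(mean_column(others, 0), 3))  # -> 0.522
```

The "all" row is the unweighted mean over classes, so the pwd_early row (precision 1, recall 0) lifts the overall precision from about 0.52 to 0.60 even though no pwd_early instance was recalled. A precision of 1 with a recall of 0 usually indicates that the model produced no usable detections for that class, so the overall precision should be read with that in mind.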
5. Model export
Export to TensorRT, which can speed up GPU inference by up to 5x:
from ultralytics import YOLO
model = YOLO("pine_wilt_disease/yolo11s_10/weights/best.pt")
# TensorRT FP16
model.export(format="engine", imgsz=640, batch=1, half=True)
WARNING ⚠️ TensorRT requires GPU export, automatically assigning device=0
Ultralytics 8.3.55 🚀 Python-3.10.12 torch-2.5.0a0+872d972e41.nv24.08 CUDA:0 (Orin, 7620MiB)
YOLO11s summary (fused): 238 layers, 9,416,283 parameters, 0 gradients, 21.3 GFLOPs
PyTorch: starting from 'pine_wilt_disease/yolo11s_10/weights/best.pt' with input shape (1, 3, 640, 640) BCHW and output shape(s) (1, 13, 8400) (18.3 MB)
ONNX: starting export with onnx 1.17.0 opset 19...
ONNX: slimming with onnxslim 0.1.47...
ONNX: export success ✅ 3.2s, saved as 'pine_wilt_disease/yolo11s_10/weights/best.onnx' (36.2 MB)
TensorRT: starting export with TensorRT 10.7.0...
[11/15/2025-16:23:19] [TRT] [I] [MemUsageChange] Init CUDA: CPU -2, GPU +0, now: CPU 1395, GPU 7158 (MiB)
[11/15/2025-16:23:25] [TRT] [I] [MemUsageChange] Init builder kernel library: CPU +970, GPU +258, now: CPU 2322, GPU 7418 (MiB)
[11/15/2025-16:23:26] [TRT] [I] ----------------------------------------------------------------
[11/15/2025-16:23:26] [TRT] [I] Input filename: pine_wilt_disease/yolo11s_10/weights/best.onnx
[11/15/2025-16:23:26] [TRT] [I] ONNX IR version: 0.0.9
[11/15/2025-16:23:26] [TRT] [I] Opset version: 19
[11/15/2025-16:23:26] [TRT] [I] Producer name: pytorch
[11/15/2025-16:23:26] [TRT] [I] Producer version: 2.5.0
[11/15/2025-16:23:26] [TRT] [I] Domain:
[11/15/2025-16:23:26] [TRT] [I] Model version: 0
[11/15/2025-16:23:26] [TRT] [I] Doc string:
[11/15/2025-16:23:26] [TRT] [I] ----------------------------------------------------------------
TensorRT: input "images" with shape(1, 3, 640, 640) DataType.FLOAT
TensorRT: output "output0" with shape(1, 13, 8400) DataType.FLOAT
TensorRT: building FP16 engine as pine_wilt_disease/yolo11s_10/weights/best.engine
[11/15/2025-16:23:26] [TRT] [I] Local timing cache in use. Profiling results in this builder pass will not be stored.
[11/15/2025-16:28:08] [TRT] [I] Compiler backend is used during engine build.
[11/15/2025-16:31:48] [TRT] [I] Detected 1 inputs and 1 output network tensors.
[11/15/2025-16:31:53] [TRT] [I] Total Host Persistent Memory: 543184 bytes
[11/15/2025-16:31:53] [TRT] [I] Total Device Persistent Memory: 0 bytes
[11/15/2025-16:31:53] [TRT] [I] Max Scratch Memory: 2764800 bytes
[11/15/2025-16:31:53] [TRT] [I] [BlockAssignment] Started assigning block shifts. This will take 162 steps to complete.
[11/15/2025-16:31:53] [TRT] [I] [BlockAssignment] Algorithm ShiftNTopDown took 19.653ms to assign 10 blocks to 162 nodes requiring 19046912 bytes.
[11/15/2025-16:31:53] [TRT] [I] Total Activation Memory: 19046400 bytes
[11/15/2025-16:31:53] [TRT] [I] Total Weights Memory: 18914082 bytes
[11/15/2025-16:31:53] [TRT] [I] Compiler backend is used during engine execution.
[11/15/2025-16:31:53] [TRT] [I] Engine generation completed in 506.948 seconds.
[11/15/2025-16:31:53] [TRT] [I] [MemUsageStats] Peak memory usage of TRT CPU/GPU memory allocators: CPU 2 MiB, GPU 140 MiB
TensorRT: export success ✅ 519.0s, saved as 'pine_wilt_disease/yolo11s_10/weights/best.engine' (21.6 MB)
Export complete (519.7s)
Results saved to /home/jetson/ultralytics/pine_wilt_disease/yolo11s_10/weights
Predict: yolo predict task=detect model=pine_wilt_disease/yolo11s_10/weights/best.engine imgsz=640 half
Validate: yolo val task=detect model=pine_wilt_disease/yolo11s_10/weights/best.engine imgsz=640 data=pwd.yaml half
Visualize: https://netron.app
Exporting the FP16 engine takes roughly 10 minutes on this device (519.7 s in the run above).
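
The file sizes in the export log follow from the parameter count and the bytes used per weight. The back-of-the-envelope sketch below uses only numbers printed in the log; it ignores non-weight data stored in each file, so the results are approximate.

```python
PARAMS = 9_416_283  # "9,416,283 parameters" from the model summary

fp32_bytes = PARAMS * 4  # 4 bytes per weight in FP32
fp16_bytes = PARAMS * 2  # 2 bytes per weight in FP16

MIB = 1 << 20
print(f"FP32 weights: {fp32_bytes / MIB:.1f} MiB")  # -> FP32 weights: 35.9 MiB
print(f"FP16 weights: {fp16_bytes / MIB:.1f} MiB")  # -> FP16 weights: 18.0 MiB

# TensorRT reports "Total Weights Memory: 18914082 bytes" for the FP16 engine
trt_weights_bytes = 18_914_082
print(f"{trt_weights_bytes / fp16_bytes:.3f}")  # -> 1.004
```

The FP32 estimate is close to the 36.2 MB ONNX file, and the FP16 estimate is within about 0.4% of the weight memory TensorRT reports, which is consistent with half=True storing each weight in 2 bytes instead of 4.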
6. Model inference
- Load the TensorRT engine and run inference on some of the validation images. Inference takes about 20 ms per image:
import cv2
import glob
from ultralytics import YOLO
import matplotlib.pyplot as plt
%matplotlib inline
# Load the TensorRT engine model
model = YOLO("pine_wilt_disease/yolo11s_10/weights/best.engine")
# Define the prediction function
def predict(image_path):
    results = model.predict(image_path, conf=0.45, iou=0.55)
    return results[0].plot()  # annotated image as a BGR array
# Load the images for inference
images_path = glob.glob("dataset/pwd/val/images/*.jpeg")
# Perform inference and display results
for image_path in images_path[:10]:
    result = predict(image_path)
    result = cv2.cvtColor(result, cv2.COLOR_BGR2RGB)
    result = cv2.resize(result, (4032 // 4, 3024 // 4))
    plt.imshow(result)
    plt.axis("off")
    plt.show()
WARNING ⚠️ Unable to automatically guess model task, assuming 'task=detect'. Explicitly define task for your model, i.e. 'task=detect', 'segment', 'classify','pose' or 'obb'.
Loading pine_wilt_disease/yolo11s_10/weights/best.engine for TensorRT inference...
[11/15/2025-16:31:54] [TRT] [I] Loaded engine size: 21 MiB
[11/15/2025-16:31:54] [TRT] [I] [MemUsageChange] TensorRT-managed allocation in IExecutionContext creation: CPU +0, GPU +18, now: CPU 0, GPU 36 (MiB)
image 1/1 /home/jetson/ultralytics/dataset/pwd/val/images/1d1d160a-ae4f-4fe4-801d-f001d4e7ff6d.jpeg: 640x640 4 nets, 1 pwd, 1 dead_recent, 20.1ms
Speed: 45.6ms preprocess, 20.1ms inference, 49.3ms postprocess per image at shape (1, 3, 640, 640)
...
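
The timings printed in the logs can be put side by side with a few lines of arithmetic. The numbers below are copied from the validation log (PyTorch) and from the first prediction above (TensorRT); the comparison is only indicative, since validation ran with batch=8 while prediction ran on single images, and the first call usually includes warm-up overhead.

```python
# Per-image timings in milliseconds, copied from the logs
pytorch_inference_ms = 24.5  # model.val() with best.pt
trt_preprocess_ms = 45.6     # model.predict() with best.engine, first image
trt_inference_ms = 20.1
trt_postprocess_ms = 49.3

speedup = pytorch_inference_ms / trt_inference_ms
total_ms = trt_preprocess_ms + trt_inference_ms + trt_postprocess_ms

print(f"inference speedup: {speedup:.2f}x")          # -> inference speedup: 1.22x
print(f"end-to-end: {total_ms:.1f} ms per image")    # -> end-to-end: 115.0 ms per image
print(f"end-to-end: {1000 / total_ms:.1f} images/s") # -> end-to-end: 8.7 images/s
```

In this run the measured inference speedup is about 1.2x, well short of the 5x upper bound, and pre- and postprocessing dominate the end-to-end time for the first image. Later images may be faster once the pipeline is warm.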

- Build a Gradio application that lets users upload an image and runs pine wilt disease detection on it.
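
The application code is not included in this section. A minimal sketch of such an app might look like the following. It assumes the engine path from the export step and reuses the thresholds from the inference cell as slider defaults; the function names `build_app` and `detect`, the labels, and the slider ranges are hypothetical choices, not the author's implementation.

```python
MODEL_PATH = "pine_wilt_disease/yolo11s_10/weights/best.engine"


def build_app(model_path=MODEL_PATH):
    """Build a Gradio interface around the exported detector (illustrative sketch)."""
    # Heavy dependencies are imported here so the module can be loaded without them
    import gradio as gr
    import numpy as np
    from ultralytics import YOLO

    model = YOLO(model_path, task="detect")

    def detect(image, conf, iou):
        if image is None:  # nothing uploaded yet
            return None
        # Gradio passes an RGB array; Ultralytics treats numpy input as BGR
        bgr = np.ascontiguousarray(image[..., ::-1])
        results = model.predict(bgr, conf=conf, iou=iou, imgsz=640)
        annotated_bgr = results[0].plot()  # boxes and labels drawn on a BGR copy
        return np.ascontiguousarray(annotated_bgr[..., ::-1])  # back to RGB

    return gr.Interface(
        fn=detect,
        inputs=[
            gr.Image(type="numpy", label="Drone image"),
            gr.Slider(minimum=0.05, maximum=0.95, value=0.45, step=0.05,
                      label="Confidence threshold"),
            gr.Slider(minimum=0.05, maximum=0.95, value=0.55, step=0.05,
                      label="IoU threshold"),
        ],
        outputs=gr.Image(label="Detections"),
        title="Pine wilt disease detection",
    )
```

Calling `build_app().launch()` starts the web server. On a Jetson reached over the network, passing `server_name="0.0.0.0"` to `launch()` makes the page available to other machines.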

This concludes the chapter.