YOLOv10 + DeepSORT 目标跟踪实现

举报
鱼弦 发表于 2024/09/06 09:26:47 2024/09/06
【摘要】 YOLOv10 + DeepSORT 目标跟踪实现 介绍YOLOv10(You Only Look Once, Version 10)是一种先进的实时对象检测网络,能够在单次神经网络传递中预测多个目标。DeepSORT(Simple Online and Realtime Tracking with a Deep Association Metric)是一个常用的多目标跟踪算法,它结合了深...

YOLOv10 + DeepSORT 目标跟踪实现

介绍

YOLOv10(You Only Look Once, Version 10)是一种先进的实时对象检测网络,能够在单次神经网络传递中预测多个目标。DeepSORT(Simple Online and Realtime Tracking with a Deep Association Metric)是一个常用的多目标跟踪算法,它结合了深度学习的特征提取和经典的匈牙利算法进行关联匹配。

将YOLOv10与DeepSORT结合,可以实现高效精确的多目标跟踪,在视频监控、自动驾驶、智能交通等领域具有广泛应用。

应用使用场景

  • 视频监控:实时跟踪监控视频中的人员、车辆等关键目标。
  • 自动驾驶:识别并跟踪道路上的行人、车辆等,提高自动驾驶系统的安全性。
  • 智能交通:管理交通流量,检测和跟踪拥堵情况及违规行为。
  • 体育分析:跟踪运动员动作,提供战术分析和运动数据统计。

下面是一些关于视频监控、自动驾驶、智能交通和体育分析的代码示例。这些示例使用Python和一些流行的计算机视觉库,如OpenCV和TensorFlow。

视频监控:实时跟踪监控视频中的人员、车辆等关键目标

import cv2

# Load pre-trained model and video
net = cv2.dnn.readNetFromCaffe("deploy.prototxt", "res10_300x300_ssd_iter_140000.caffemodel")
video_capture = cv2.VideoCapture('video.mp4')

while True:
    ret, frame = video_capture.read()
    if not ret:
        break
        
    # Prepare the frame for detection
    blob = cv2.dnn.blobFromImage(frame, 1.0, (300, 300), (104.0, 177.0, 123.0))
    net.setInput(blob)
    detections = net.forward()

    for i in range(detections.shape[2]):
        confidence = detections[0, 0, i, 2]
        if confidence > 0.5:
            box = detections[0, 0, i, 3:7] * np.array([w, h, w, h])
            (startX, startY, endX, endY) = box.astype("int")
            cv2.rectangle(frame, (startX, startY), (endX, endY), (0, 255, 0), 2)
            
    cv2.imshow('Video', frame)

    if cv2.waitKey(1) & 0xFF == ord('q'):
        break

video_capture.release()
cv2.destroyAllWindows()

自动驾驶:识别并跟踪道路上的行人、车辆等,提高自动驾驶系统的安全性

import tensorflow as tf
import numpy as np
import cv2

model = tf.saved_model.load('ssd_mobilenet_v2_fpnlite')
category_index = {1: {"id": 1, "name": "person"}, 2: {"id": 2, "name": "vehicle"}}

cap = cv2.VideoCapture('driving_video.mp4')

while cap.isOpened():
    ret, frame = cap.read()
    if not ret:
        break
        
    input_tensor = tf.convert_to_tensor(np.expand_dims(frame, 0), dtype=tf.float32)
    detections = model(input_tensor)

    for i in range(int(detections.pop('num_detections'))):
        class_id = int(detections['detection_classes'][i].numpy())
        score = float(detections['detection_scores'][i].numpy())
        bbox = detections['detection_boxes'][i].numpy()

        if score > 0.5:
            h, w, _ = frame.shape
            ymin, xmin, ymax, xmax = bbox
            left, right, top, bottom = int(xmin * w), int(xmax * w), int(ymin * h), int(ymax * h)
            label = category_index[class_id]["name"]
            cv2.rectangle(frame, (left, top), (right, bottom), (0, 255, 0), 2)
            cv2.putText(frame, f"{label}: {score:.2f}", (left, top - 10), cv2.FONT_HERSHEY_SIMPLEX, 0.5, (255, 0, 0), 2)
            
    cv2.imshow('Driving Video', frame)

    if cv2.waitKey(1) & 0xFF == ord('q'):
        break

cap.release()
cv2.destroyAllWindows()

智能交通:管理交通流量,检测和跟踪拥堵情况及违规行为

import cv2

car_cascade = cv2.CascadeClassifier('cars.xml')
cap = cv2.VideoCapture('traffic.mp4')

while cap.isOpened():
    ret, frame = cap.read()
    if not ret:
        break
    
    gray = cv2.cvtColor(frame, cv2.COLOR_BGR2GRAY)
    cars = car_cascade.detectMultiScale(gray, 1.1, 1)
    
    for (x, y, w, h) in cars:
        cv2.rectangle(frame, (x, y), (x + w, y + h), (0, 0, 255), 2)
        
    cv2.imshow('Traffic Monitoring', frame)

    if cv2.waitKey(1) & 0xFF == ord('q'):
        break

cap.release()
cv2.destroyAllWindows()

体育分析:跟踪运动员动作,提供战术分析和运动数据统计

import cv2
import mediapipe as mp

mp_pose = mp.solutions.pose
pose = mp_pose.Pose()

cap = cv2.VideoCapture('sports.mp4')

while cap.isOpened():
    ret, frame = cap.read()
    if not ret:
        break

    image_rgb = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
    results = pose.process(image_rgb)

    if results.pose_landmarks:
        mp.solutions.drawing_utils.draw_landmarks(
            frame, results.pose_landmarks, mp_pose.POSE_CONNECTIONS)

    cv2.imshow('Sports Analysis', frame)

    if cv2.waitKey(1) & 0xFF == ord('q'):
        break

cap.release()
cv2.destroyAllWindows()

这些代码示例使用了不同的库和模型来实现各自的功能。请根据具体需求安装相应的依赖库,例如opencv-python, tensorflow, mediapipe等。

原理解释

YOLOv10

YOLOv10利用卷积神经网络(CNN)对输入图像进行处理,将图像分成SxS的网格,每个网格负责预测若干个边界框及其置信度和类别概率。网络通过非极大值抑制(NMS)去除冗余的盒子,并输出最终的目标检测结果。

DeepSORT

DeepSORT通过卡尔曼滤波器进行状态估计,并使用深度学习特征(ReID特征)进行目标关联匹配。其工作流程可以概括为以下几个步骤:

  1. 检测:使用YOLOv10获取目标检测框。
  2. 特征提取:从检测框中提取目标的ReID特征。
  3. 卡尔曼滤波:预测目标在下一时刻的位置。
  4. 数据关联:使用匈牙利算法将检测框与预测框进行关联匹配。
  5. 更新轨迹:基于匹配结果更新目标的轨迹信息。

算法原理流程图

输入图像
YOLOv10
检测框
提取ReID特征
卡尔曼滤波预测
匈牙利算法关联
更新轨迹
输出跟踪结果

实际应用代码示例实现

环境准备

pip install torch torchvision torchaudio
pip install yolov10_deepsort

YOLOv10 + DeepSORT 实现代码

import cv2
from yolov10 import YOLOv10
from deepsort import DeepSort

# 初始化YOLOv10模型
yolo = YOLOv10('path_to_yolov10_weights')
deep_sort = DeepSort('path_to_deepsort_checkpoint')

# 打开视频文件或摄像头
cap = cv2.VideoCapture('video.mp4')

while True:
    ret, frame = cap.read()
    if not ret:
        break
    
    # 使用YOLOv10检测目标
    detections = yolo.detect(frame)
    
    # 使用DeepSORT进行目标跟踪
    tracked_objects = deep_sort.update(detections, frame)
    
    # 绘制检测框和跟踪轨迹
    for obj in tracked_objects:
        bbox = obj['bbox']
        id = obj['id']
        cv2.rectangle(frame, (bbox[0], bbox[1]), (bbox[2], bbox[3]), (0, 255, 0), 2)
        cv2.putText(frame, f'ID: {id}', (bbox[0], bbox[1] - 10), cv2.FONT_HERSHEY_SIMPLEX, 0.9, (36,255,12), 2)

    cv2.imshow('Frame', frame)
    if cv2.waitKey(1) & 0xFF == ord('q'):
        break

cap.release()
cv2.destroyAllWindows()

测试代码

上述示例代码即可完成测试。运行后会打开视频并显示检测和跟踪结果。

部署场景

  • 云端部署:可将模型部署在云端,提供RESTful API接口供客户端调用。
  • 本地部署:用于如监控摄像头等本地设备,直接在设备上运行模型。
  • 边缘计算:在如嵌入式设备、树莓派等低功耗设备上运行,实现实时跟踪。

材料链接

总结

结合YOLOv10和DeepSORT的目标跟踪方法,通过高效的目标检测和精确的目标关联匹配,实现了实时、多目标、高鲁棒性的跟踪。这种方法在多种实际场景中展现出卓越的性能,是当前目标跟踪领域的重要技术之一。

未来展望

未来,随着深度学习和计算资源的发展,这类目标跟踪技术将进一步提高精度和效率。我们期望看到更多创新的模型和算法,如结合Transformer、自监督学习等技术,实现更智能、更高效的目标跟踪解决方案。同时,针对特定应用场景的优化与适配也将成为研究的重点方向。

【版权声明】本文为华为云社区用户原创内容,转载时必须标注文章的来源(华为云社区)、文章链接、文章作者等基本信息, 否则作者和本社区有权追究责任。如果您发现本社区中有涉嫌抄袭的内容,欢迎发送邮件进行举报,并提供相关证据,一经查实,本社区将立刻删除涉嫌侵权内容,举报邮箱: cloudbbs@huaweicloud.com
  • 点赞
  • 收藏
  • 关注作者

评论(0

0/1000
抱歉,系统识别当前为高风险访问,暂不支持该操作

全部回复

上滑加载中

设置昵称

在此一键设置昵称,即可参与社区互动!

*长度不超过10个汉字或20个英文字符,设置后3个月内不可修改。

*长度不超过10个汉字或20个英文字符,设置后3个月内不可修改。