faster rcnn的推理代码解析

举报
黄生 发表于 2025/02/01 14:41:36 2025/02/01
【摘要】 推理代码的主要功能是使用Faster R-CNN目标检测模型对输入的测试图片进行目标检测,并可视化检测结果。 代码的整体框架和流程导入必要的库:包括matplotlib、numpy、cv2、torch、argparse等常用的科学计算和图像处理库。导入自定义模块,如tools._init_paths、model.config、model.test等。定义类别和模型路径:CLASSES定义了目...

推理代码的主要功能是使用Faster R-CNN目标检测模型对输入的测试图片进行目标检测,并可视化检测结果。

代码的整体框架和流程

  1. 导入必要的库

    • 包括matplotlibnumpycv2torchargparse等常用的科学计算和图像处理库。
    • 导入自定义模块,如tools._init_pathsmodel.configmodel.test等。
  2. 定义类别和模型路径

    • CLASSES定义了目标检测的数据集类别,包括背景类。
    • NETSDATASETS定义了网络模型和数据集的路径模板。
  3. 目标检测结果可视化函数

    • vis_detections函数用于绘制检测到的物体边框和类别标签。
  4. 加载预训练模型并进行推理

    • 加载Faster R-CNN预训练模型(vgg16_faster_rcnn_iter_110000.pth)。
    • 对测试文件夹中的图片进行目标检测,输出检测边框和置信度,并可视化结果。

代码的关键部分解析

1. 目标检测可视化函数

def vis_detections(im, class_dets, thresh=0.5):
    """Draw detected bounding boxes."""
    im = im[:, :, (2, 1, 0)]  # OpenCV的图片格式是BGR,而matplotlib需要RGB格式
    fig, ax = plt.subplots(figsize=(12, 12))
    ax.imshow(im, aspect='equal')  # 显示图片
    for class_name in class_dets:
        dets = class_dets[class_name]
        inds = np.where(dets[:, -1] >= thresh)[0]  # 筛选置信度大于阈值的检测结果
        if len(inds) == 0:
            continue
        for i in inds:
            bbox = dets[i, :4]  # 获取边界框的坐标
            score = dets[i, -1]  # 获取置信度
            # 绘制矩形框
            ax.add_patch(plt.Rectangle((bbox[0], bbox[1]),
                                       bbox[2] - bbox[0],
                                       bbox[3] - bbox[1],
                                       fill=False,
                                       edgecolor='red',
                                       linewidth=3.5))
            # 添加类别标签和置信度
            ax.text(bbox[0], bbox[1] - 2,
                    '{:s} {:.3f}'.format(class_name, score),
                    bbox=dict(facecolor='blue', alpha=0.5),
                    fontsize=14,
                    color='white')
    plt.axis('off')  # 去掉坐标轴
    plt.tight_layout()
    plt.draw()

作用:将检测到的目标以矩形框绘制在图像上,并标注类别和置信度。

2. 加载模型和权重

# 加载模型
net = vgg16()  # 使用VGG16作为backbone
net.create_architecture(21, tag='default', anchor_scales=[8, 16, 32])  # 创建模型架构
net.load_state_dict(torch.load(saved_model, map_location=lambda storage, loc: storage))  # 加载预训练权重
net.eval()  # 设置为评估模式
net.to(net._device)  # 将模型移动到指定设备(CPU或GPU)

作用

  • 使用VGG16网络作为特征提取的backbone。
  • 加载预训练的Faster R-CNN模型权重文件。
  • 设置模型为评估模式,以确保推理过程中不会启用梯度计算。

3. 目标检测推理

# 检测图片
scores, boxes = im_detect(net, im)  # 调用im_detect函数,获取检测得分和边界框
keep = nms(torch.from_numpy(cls_boxes), torch.from_numpy(cls_scores), NMS_THRESH)  # 非极大值抑制
dets = dets[keep.numpy(), :]  # 保留NMS后的边界框

作用

  • im_detect函数用于对输入图片进行目标检测,输出类别得分和预测边界框。
  • 使用非极大值抑制(NMS)筛选掉重复的边界框,只保留置信度最高的框。

4. 可视化结果

vis_detections(im, cls_dets, thresh=CONF_THRESH)
plt.show()

作用

  • 将检测结果绘制在图片上,并展示出来。

代码的运行流程

  1. 定义类别和模型路径

    • 设置目标检测数据集的类别和预训练模型路径。
  2. 加载模型

    • 加载Faster R-CNN预训练模型,并准备推理。
  3. 读取测试图片

    • 遍历测试文件夹中的图片,逐个进行检测。
  4. 目标检测

    • 使用im_detect函数对图片进行目标检测,获取边界框和置信度。
  5. 非极大值抑制

    • 使用NMS去除重复的边界框。
  6. 可视化结果

    • 将检测到的目标绘制在图片上,并展示输出。

推理代码如下:

%matplotlib inline
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
# 将路径转入lib
import tools._init_paths

from model.config import cfg
from model.test import im_detect
from torchvision.ops import nms

from utils.timer import Timer
import matplotlib.pyplot as plt
import numpy as np
import os, cv2
import argparse

from nets.vgg16 import vgg16
from nets.resnet_v1 import resnetv1
from model.bbox_transform import clip_boxes, bbox_transform_inv

import torch
# PASCAL VOC类别设置
CLASSES = ('__background__',
           'aeroplane', 'bicycle', 'bird', 'boat',
           'bottle', 'bus', 'car', 'cat', 'chair',
           'cow', 'diningtable', 'dog', 'horse',
           'motorbike', 'xperson', 'pottedplant',
           'sheep', 'sofa', 'train', 'tvmonitor')
# 网络模型文件名定义
NETS = {'vgg16': ('vgg16_faster_rcnn_iter_%d.pth',),'res101': ('res101_faster_rcnn_iter_%d.pth',)}
# 数据集文件名定义
DATASETS= {'pascal_voc': ('voc_2007_trainval',),'pascal_voc_0712': ('voc_2007_trainval+voc_2012_trainval',)}
def vis_detections(im, class_dets, thresh=0.5):
    """Draw detected bounding boxes."""
    im = im[:, :, (2, 1, 0)]
    fig, ax = plt.subplots(figsize=(12, 12))
    ax.imshow(im, aspect='equal')
    for class_name in class_dets:
        dets = class_dets[class_name]
        inds = np.where(dets[:, -1] >= thresh)[0]
        if len(inds) == 0:
            continue
        
        for i in inds:
            bbox = dets[i, :4]
            score = dets[i, -1]

            ax.add_patch(
                plt.Rectangle((bbox[0], bbox[1]),
                              bbox[2] - bbox[0],
                              bbox[3] - bbox[1], fill=False,
                              edgecolor='red', linewidth=3.5)
                )
            ax.text(bbox[0], bbox[1] - 2,
                    '{:s} {:.3f}'.format(class_name, score),
                    bbox=dict(facecolor='blue', alpha=0.5),
                    fontsize=14, color='white')

        plt.axis('off')
        plt.tight_layout()
        plt.draw()
test_file = "./test"
import cv2
from utils.timer import Timer
from model.test import im_detect
from torchvision.ops import nms

cfg.TEST.HAS_RPN = True  # Use RPN for proposals

# 模型存储位置
# 这里我们加载一个已经训练110000迭代之后的模型,可以选择自己的训练模型位置
#[ma-user Faster_R_CNN]$ll ./models/vgg16-voc0712/vgg16_faster_rcnn_iter_110000.pth
#-rwxr-x--- 1 ma-user ma-group 548317704 Aug  3  2017 ./models/vgg16-voc0712/vgg16_faster_rcnn_iter_110000.pth*
saved_model = "./models/vgg16-voc0712/vgg16_faster_rcnn_iter_110000.pth"
print('trying to load weights from ', saved_model)

# 加载backbone
net = vgg16()
# 构建网络
net.create_architecture(21, tag='default', anchor_scales=[8, 16, 32])
# 加载权重文件
net.load_state_dict(torch.load(saved_model, map_location=lambda storage, loc: storage))

net.eval()
# 选择推理设备
net.to(net._device)

print('Loaded network {:s}'.format(saved_model))

for file in os.listdir(test_file):
    if file.startswith("._") == False:
        file_path = os.path.join(test_file, file)
        print(file_path)
        # 打开测试图片文件
        im = cv2.imread(file_path)

        # 定义计时器
        timer = Timer()
        timer.tic()
        # 检测得到图片ROI
        scores, boxes = im_detect(net, im)

        print(scores.shape, boxes.shape)
        timer.toc()
        print('Detection took {:.3f}s for {:d} object proposals'.format(timer.total_time(), boxes.shape[0]))

        # 定义阈值
        CONF_THRESH = 0.7
        NMS_THRESH = 0.3

        cls_dets = {}

        # NMS 非极大值抑制操作,过滤边界框
        for cls_ind, cls in enumerate(CLASSES[1:]):
            cls_ind += 1 # 跳过 background
            cls_boxes = boxes[:, 4*cls_ind:4*(cls_ind + 1)]
            cls_scores = scores[:, cls_ind]
            dets = np.hstack((cls_boxes,
                              cls_scores[:, np.newaxis])).astype(np.float32)
            keep = nms(torch.from_numpy(cls_boxes), torch.from_numpy(cls_scores), NMS_THRESH)
            dets = dets[keep.numpy(), :]

            if len(dets) > 0:
                if cls in cls_dets:
                    cls_dets[cls] = np.vstack([cls_dets[cls], dets]) 
                else:
                    cls_dets[cls] = dets
        vis_detections(im, cls_dets, thresh=CONF_THRESH)

        plt.show()

附:代码里继续训练采用的PASCAL VOC 数据集,其共有 20种对象,具体如下:

  • person
  • 动物bird(鸟)、cat(猫)、cow(牛)、dog(狗)、horse(马)、sheep(羊)
  • 交通工具aeroplane(飞机)、bicycle(自行车)、boat(船)、bus(公共汽车)、car(汽车)、motorbike(摩托车)、train(火车)
  • 室内物品bottle(瓶子)、chair(椅子)、diningtable(餐桌)、pottedplant(盆栽植物)、sofa(沙发)、tvmonitor(电视/显示器)

notebook链接:https://developer.huaweicloud.com/develop/aigallery/notebook/detail?id=577028db-9eae-48b4-88b3-9f46169b8515&ticket=ST-8276091-Kk3RAodY51Lc96G7nuKi2dPH-sso&locale=zh-cn

【声明】本内容来自华为云开发者社区博主,不代表华为云及华为云开发者社区的观点和立场。转载时必须标注文章的来源(华为云社区)、文章链接、文章作者等基本信息,否则作者和本社区有权追究责任。如果您发现本社区中有涉嫌抄袭的内容,欢迎发送邮件进行举报,并提供相关证据,一经查实,本社区将立刻删除涉嫌侵权内容,举报邮箱: cloudbbs@huaweicloud.com
  • 点赞
  • 收藏
  • 关注作者

评论(0

0/1000
抱歉,系统识别当前为高风险访问,暂不支持该操作

全部回复

上滑加载中

设置昵称

在此一键设置昵称,即可参与社区互动!

*长度不超过10个汉字或20个英文字符,设置后3个月内不可修改。

*长度不超过10个汉字或20个英文字符,设置后3个月内不可修改。