faster rcnn的推理代码解析
【摘要】 推理代码的主要功能是使用Faster R-CNN目标检测模型对输入的测试图片进行目标检测,并可视化检测结果。 代码的整体框架和流程导入必要的库:包括matplotlib、numpy、cv2、torch、argparse等常用的科学计算和图像处理库。导入自定义模块,如tools._init_paths、model.config、model.test等。定义类别和模型路径:CLASSES定义了目...
推理代码的主要功能是使用Faster R-CNN目标检测模型对输入的测试图片进行目标检测,并可视化检测结果。
代码的整体框架和流程
-
导入必要的库:
- 包括
matplotlib
、numpy
、cv2
、torch
、argparse
等常用的科学计算和图像处理库。 - 导入自定义模块,如
tools._init_paths
、model.config
、model.test
等。
- 包括
-
定义类别和模型路径:
CLASSES
定义了目标检测的数据集类别,包括背景类。NETS
和DATASETS
定义了网络模型和数据集的路径模板。
-
目标检测结果可视化函数:
vis_detections
函数用于绘制检测到的物体边框和类别标签。
-
加载预训练模型并进行推理:
- 加载Faster R-CNN预训练模型(
vgg16_faster_rcnn_iter_110000.pth
)。 - 对测试文件夹中的图片进行目标检测,输出检测边框和置信度,并可视化结果。
- 加载Faster R-CNN预训练模型(
代码的关键部分解析
1. 目标检测可视化函数
def vis_detections(im, class_dets, thresh=0.5):
"""Draw detected bounding boxes."""
im = im[:, :, (2, 1, 0)] # OpenCV的图片格式是BGR,而matplotlib需要RGB格式
fig, ax = plt.subplots(figsize=(12, 12))
ax.imshow(im, aspect='equal') # 显示图片
for class_name in class_dets:
dets = class_dets[class_name]
inds = np.where(dets[:, -1] >= thresh)[0] # 筛选置信度大于阈值的检测结果
if len(inds) == 0:
continue
for i in inds:
bbox = dets[i, :4] # 获取边界框的坐标
score = dets[i, -1] # 获取置信度
# 绘制矩形框
ax.add_patch(plt.Rectangle((bbox[0], bbox[1]),
bbox[2] - bbox[0],
bbox[3] - bbox[1],
fill=False,
edgecolor='red',
linewidth=3.5))
# 添加类别标签和置信度
ax.text(bbox[0], bbox[1] - 2,
'{:s} {:.3f}'.format(class_name, score),
bbox=dict(facecolor='blue', alpha=0.5),
fontsize=14,
color='white')
plt.axis('off') # 去掉坐标轴
plt.tight_layout()
plt.draw()
作用:将检测到的目标以矩形框绘制在图像上,并标注类别和置信度。
2. 加载模型和权重
# 加载模型
net = vgg16() # 使用VGG16作为backbone
net.create_architecture(21, tag='default', anchor_scales=[8, 16, 32]) # 创建模型架构
net.load_state_dict(torch.load(saved_model, map_location=lambda storage, loc: storage)) # 加载预训练权重
net.eval() # 设置为评估模式
net.to(net._device) # 将模型移动到指定设备(CPU或GPU)
作用:
- 使用VGG16网络作为特征提取的backbone。
- 加载预训练的Faster R-CNN模型权重文件。
- 设置模型为评估模式,以确保推理过程中不会启用梯度计算。
3. 目标检测推理
# 检测图片
scores, boxes = im_detect(net, im) # 调用im_detect函数,获取检测得分和边界框
keep = nms(torch.from_numpy(cls_boxes), torch.from_numpy(cls_scores), NMS_THRESH) # 非极大值抑制
dets = dets[keep.numpy(), :] # 保留NMS后的边界框
作用:
im_detect
函数用于对输入图片进行目标检测,输出类别得分和预测边界框。- 使用非极大值抑制(NMS)筛选掉重复的边界框,只保留置信度最高的框。
4. 可视化结果
vis_detections(im, cls_dets, thresh=CONF_THRESH)
plt.show()
作用:
- 将检测结果绘制在图片上,并展示出来。
代码的运行流程
-
定义类别和模型路径:
- 设置目标检测数据集的类别和预训练模型路径。
-
加载模型:
- 加载Faster R-CNN预训练模型,并准备推理。
-
读取测试图片:
- 遍历测试文件夹中的图片,逐个进行检测。
-
目标检测:
- 使用
im_detect
函数对图片进行目标检测,获取边界框和置信度。
- 使用
-
非极大值抑制:
- 使用NMS去除重复的边界框。
-
可视化结果:
- 将检测到的目标绘制在图片上,并展示输出。
推理代码如下:
%matplotlib inline
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
# 将路径转入lib
import tools._init_paths
from model.config import cfg
from model.test import im_detect
from torchvision.ops import nms
from utils.timer import Timer
import matplotlib.pyplot as plt
import numpy as np
import os, cv2
import argparse
from nets.vgg16 import vgg16
from nets.resnet_v1 import resnetv1
from model.bbox_transform import clip_boxes, bbox_transform_inv
import torch
# PASCAL VOC类别设置
CLASSES = ('__background__',
'aeroplane', 'bicycle', 'bird', 'boat',
'bottle', 'bus', 'car', 'cat', 'chair',
'cow', 'diningtable', 'dog', 'horse',
'motorbike', 'xperson', 'pottedplant',
'sheep', 'sofa', 'train', 'tvmonitor')
# 网络模型文件名定义
NETS = {'vgg16': ('vgg16_faster_rcnn_iter_%d.pth',),'res101': ('res101_faster_rcnn_iter_%d.pth',)}
# 数据集文件名定义
DATASETS= {'pascal_voc': ('voc_2007_trainval',),'pascal_voc_0712': ('voc_2007_trainval+voc_2012_trainval',)}
def vis_detections(im, class_dets, thresh=0.5):
"""Draw detected bounding boxes."""
im = im[:, :, (2, 1, 0)]
fig, ax = plt.subplots(figsize=(12, 12))
ax.imshow(im, aspect='equal')
for class_name in class_dets:
dets = class_dets[class_name]
inds = np.where(dets[:, -1] >= thresh)[0]
if len(inds) == 0:
continue
for i in inds:
bbox = dets[i, :4]
score = dets[i, -1]
ax.add_patch(
plt.Rectangle((bbox[0], bbox[1]),
bbox[2] - bbox[0],
bbox[3] - bbox[1], fill=False,
edgecolor='red', linewidth=3.5)
)
ax.text(bbox[0], bbox[1] - 2,
'{:s} {:.3f}'.format(class_name, score),
bbox=dict(facecolor='blue', alpha=0.5),
fontsize=14, color='white')
plt.axis('off')
plt.tight_layout()
plt.draw()
test_file = "./test"
import cv2
from utils.timer import Timer
from model.test import im_detect
from torchvision.ops import nms
cfg.TEST.HAS_RPN = True # Use RPN for proposals
# 模型存储位置
# 这里我们加载一个已经训练110000迭代之后的模型,可以选择自己的训练模型位置
#[ma-user Faster_R_CNN]$ll ./models/vgg16-voc0712/vgg16_faster_rcnn_iter_110000.pth
#-rwxr-x--- 1 ma-user ma-group 548317704 Aug 3 2017 ./models/vgg16-voc0712/vgg16_faster_rcnn_iter_110000.pth*
saved_model = "./models/vgg16-voc0712/vgg16_faster_rcnn_iter_110000.pth"
print('trying to load weights from ', saved_model)
# 加载backbone
net = vgg16()
# 构建网络
net.create_architecture(21, tag='default', anchor_scales=[8, 16, 32])
# 加载权重文件
net.load_state_dict(torch.load(saved_model, map_location=lambda storage, loc: storage))
net.eval()
# 选择推理设备
net.to(net._device)
print('Loaded network {:s}'.format(saved_model))
for file in os.listdir(test_file):
if file.startswith("._") == False:
file_path = os.path.join(test_file, file)
print(file_path)
# 打开测试图片文件
im = cv2.imread(file_path)
# 定义计时器
timer = Timer()
timer.tic()
# 检测得到图片ROI
scores, boxes = im_detect(net, im)
print(scores.shape, boxes.shape)
timer.toc()
print('Detection took {:.3f}s for {:d} object proposals'.format(timer.total_time(), boxes.shape[0]))
# 定义阈值
CONF_THRESH = 0.7
NMS_THRESH = 0.3
cls_dets = {}
# NMS 非极大值抑制操作,过滤边界框
for cls_ind, cls in enumerate(CLASSES[1:]):
cls_ind += 1 # 跳过 background
cls_boxes = boxes[:, 4*cls_ind:4*(cls_ind + 1)]
cls_scores = scores[:, cls_ind]
dets = np.hstack((cls_boxes,
cls_scores[:, np.newaxis])).astype(np.float32)
keep = nms(torch.from_numpy(cls_boxes), torch.from_numpy(cls_scores), NMS_THRESH)
dets = dets[keep.numpy(), :]
if len(dets) > 0:
if cls in cls_dets:
cls_dets[cls] = np.vstack([cls_dets[cls], dets])
else:
cls_dets[cls] = dets
vis_detections(im, cls_dets, thresh=CONF_THRESH)
plt.show()
附:代码里继续训练采用的PASCAL VOC 数据集,其共有 20种对象,具体如下:
- 人:
person
- 动物:
bird
(鸟)、cat
(猫)、cow
(牛)、dog
(狗)、horse
(马)、sheep
(羊) - 交通工具:
aeroplane
(飞机)、bicycle
(自行车)、boat
(船)、bus
(公共汽车)、car
(汽车)、motorbike
(摩托车)、train
(火车) - 室内物品:
bottle
(瓶子)、chair
(椅子)、diningtable
(餐桌)、pottedplant
(盆栽植物)、sofa
(沙发)、tvmonitor
(电视/显示器)
【声明】本内容来自华为云开发者社区博主,不代表华为云及华为云开发者社区的观点和立场。转载时必须标注文章的来源(华为云社区)、文章链接、文章作者等基本信息,否则作者和本社区有权追究责任。如果您发现本社区中有涉嫌抄袭的内容,欢迎发送邮件进行举报,并提供相关证据,一经查实,本社区将立刻删除涉嫌侵权内容,举报邮箱:
cloudbbs@huaweicloud.com
- 点赞
- 收藏
- 关注作者
评论(0)