物体检测-Faster R-CNN(2)
【摘要】 测试部分在这部分中,我们利用训练得到的模型进行推理测试。In [13]:%matplotlib inlinefrom __future__ import absolute_importfrom __future__ import divisionfrom __future__ import print_function# 将路径转入libimport tools._init_pathsfro...
In [13]:
%matplotlib inline
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
# 将路径转入lib
import tools._init_paths
from model.config import cfg
from model.test import im_detect
from torchvision.ops import nms
from utils.timer import Timer
import matplotlib.pyplot as plt
import numpy as np
import os, cv2
import argparse
from nets.vgg16 import vgg16
from nets.resnet_v1 import resnetv1
from model.bbox_transform import clip_boxes, bbox_transform_inv
import torch
In [14]:
# PASCAL VOC类别设置
CLASSES = ('__background__',
'aeroplane', 'bicycle', 'bird', 'boat',
'bottle', 'bus', 'car', 'cat', 'chair',
'cow', 'diningtable', 'dog', 'horse',
'motorbike', 'person', 'pottedplant',
'sheep', 'sofa', 'train', 'tvmonitor')
# 网络模型文件名定义
NETS = {'vgg16': ('vgg16_faster_rcnn_iter_%d.pth',),'res101': ('res101_faster_rcnn_iter_%d.pth',)}
# 数据集文件名定义
DATASETS= {'pascal_voc': ('voc_2007_trainval',),'pascal_voc_0712': ('voc_2007_trainval+voc_2012_trainval',)}
In [15]:
def vis_detections(im, class_dets, thresh=0.5):
"""Draw detected bounding boxes."""
im = im[:, :, (2, 1, 0)]
fig, ax = plt.subplots(figsize=(12, 12))
ax.imshow(im, aspect='equal')
for class_name in class_dets:
dets = class_dets[class_name]
inds = np.where(dets[:, -1] >= thresh)[0]
if len(inds) == 0:
continue
for i in inds:
bbox = dets[i, :4]
score = dets[i, -1]
ax.add_patch(
plt.Rectangle((bbox[0], bbox[1]),
bbox[2] - bbox[0],
bbox[3] - bbox[1], fill=False,
edgecolor='red', linewidth=3.5)
)
ax.text(bbox[0], bbox[1] - 2,
'{:s} {:.3f}'.format(class_name, score),
bbox=dict(facecolor='blue', alpha=0.5),
fontsize=14, color='white')
plt.axis('off')
plt.tight_layout()
plt.draw()
In [16]:
test_file = "./test"
In [17]:
import cv2
from utils.timer import Timer
from model.test import im_detect
from torchvision.ops import nms
cfg.TEST.HAS_RPN = True # Use RPN for proposals
# 模型存储位置
# 这里我们加载一个已经训练110000迭代之后的模型,可以选择自己的训练模型位置
saved_model = "./models/vgg16-voc0712/vgg16_faster_rcnn_iter_110000.pth"
print('trying to load weights from ', saved_model)
# 加载backbone
net = vgg16()
# 构建网络
net.create_architecture(21, tag='default', anchor_scales=[8, 16, 32])
# 加载权重文件
net.load_state_dict(torch.load(saved_model, map_location=lambda storage, loc: storage))
net.eval()
# 选择推理设备
net.to(net._device)
print('Loaded network {:s}'.format(saved_model))
for file in os.listdir(test_file):
if file.startswith("._") == False:
file_path = os.path.join(test_file, file)
print(file_path)
# 打开测试图片文件
im = cv2.imread(file_path)
# 定义计时器
timer = Timer()
timer.tic()
# 检测得到图片ROI
scores, boxes = im_detect(net, im)
print(scores.shape, boxes.shape)
timer.toc()
print('Detection took {:.3f}s for {:d} object proposals'.format(timer.total_time(), boxes.shape[0]))
# 定义阈值
CONF_THRESH = 0.7
NMS_THRESH = 0.3
cls_dets = {}
# NMS 非极大值抑制操作,过滤边界框
for cls_ind, cls in enumerate(CLASSES[1:]):
cls_ind += 1 # 跳过 background
cls_boxes = boxes[:, 4*cls_ind:4*(cls_ind + 1)]
cls_scores = scores[:, cls_ind]
dets = np.hstack((cls_boxes,
cls_scores[:, np.newaxis])).astype(np.float32)
keep = nms(torch.from_numpy(cls_boxes), torch.from_numpy(cls_scores), NMS_THRESH)
dets = dets[keep.numpy(), :]
if len(dets) > 0:
if cls in cls_dets:
cls_dets[cls] = np.vstack([cls_dets[cls], dets])
else:
cls_dets[cls] = dets
vis_detections(im, cls_dets, thresh=CONF_THRESH)
plt.show()
【版权声明】本文为华为云社区用户原创内容,转载时必须标注文章的来源(华为云社区)、文章链接、文章作者等基本信息, 否则作者和本社区有权追究责任。如果您发现本社区中有涉嫌抄袭的内容,欢迎发送邮件进行举报,并提供相关证据,一经查实,本社区将立刻删除涉嫌侵权内容,举报邮箱:
cloudbbs@huaweicloud.com
- 点赞
- 收藏
- 关注作者
评论(0)