物体检测-Faster R-CNN(2)

举报
可爱又积极 发表于 2021/07/09 10:59:43 2021/07/09
【摘要】 测试部分在这部分中,我们利用训练得到的模型进行推理测试。In [13]:%matplotlib inlinefrom __future__ import absolute_importfrom __future__ import divisionfrom __future__ import print_function# 将路径转入libimport tools._init_pathsfro...

测试部分

在这部分中,我们利用训练得到的模型进行推理测试。

In [13]:
%matplotlib inline
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
# 将路径转入lib
import tools._init_paths

from model.config import cfg
from model.test import im_detect
from torchvision.ops import nms

from utils.timer import Timer
import matplotlib.pyplot as plt
import numpy as np
import os, cv2
import argparse

from nets.vgg16 import vgg16
from nets.resnet_v1 import resnetv1
from model.bbox_transform import clip_boxes, bbox_transform_inv

import torch

参数定义

In [14]:
# PASCAL VOC类别设置
CLASSES = ('__background__',
           'aeroplane', 'bicycle', 'bird', 'boat',
           'bottle', 'bus', 'car', 'cat', 'chair',
           'cow', 'diningtable', 'dog', 'horse',
           'motorbike', 'person', 'pottedplant',
           'sheep', 'sofa', 'train', 'tvmonitor')
# 网络模型文件名定义
NETS = {'vgg16': ('vgg16_faster_rcnn_iter_%d.pth',),'res101': ('res101_faster_rcnn_iter_%d.pth',)}
# 数据集文件名定义
DATASETS= {'pascal_voc': ('voc_2007_trainval',),'pascal_voc_0712': ('voc_2007_trainval+voc_2012_trainval',)}

结果绘制

将预测的标签和边界框绘制在原图上。

In [15]:
def vis_detections(im, class_dets, thresh=0.5):
    """Draw detected bounding boxes."""
    im = im[:, :, (2, 1, 0)]
    fig, ax = plt.subplots(figsize=(12, 12))
    ax.imshow(im, aspect='equal')
    for class_name in class_dets:
        dets = class_dets[class_name]
        inds = np.where(dets[:, -1] >= thresh)[0]
        if len(inds) == 0:
            continue
        
        for i in inds:
            bbox = dets[i, :4]
            score = dets[i, -1]

            ax.add_patch(
                plt.Rectangle((bbox[0], bbox[1]),
                              bbox[2] - bbox[0],
                              bbox[3] - bbox[1], fill=False,
                              edgecolor='red', linewidth=3.5)
                )
            ax.text(bbox[0], bbox[1] - 2,
                    '{:s} {:.3f}'.format(class_name, score),
                    bbox=dict(facecolor='blue', alpha=0.5),
                    fontsize=14, color='white')

        plt.axis('off')
        plt.tight_layout()
        plt.draw()

准备测试图片

我们将测试图片传到test文件夹下,我们准备了两张图片进行测试,大家也可以通过notebook的upload按钮上传自己的测试数据。注意,测试数据需要是图片,并且放在test文件夹下。

In [16]:
test_file = "./test"

模型推理

这里我们加载一个预先训练好的模型,也可以选择案例中训练的模型。

In [17]:
import cv2
from utils.timer import Timer
from model.test import im_detect
from torchvision.ops import nms

cfg.TEST.HAS_RPN = True  # Use RPN for proposals

# 模型存储位置
# 这里我们加载一个已经训练110000迭代之后的模型,可以选择自己的训练模型位置
saved_model = "./models/vgg16-voc0712/vgg16_faster_rcnn_iter_110000.pth"
print('trying to load weights from ', saved_model)

# 加载backbone
net = vgg16()
# 构建网络
net.create_architecture(21, tag='default', anchor_scales=[8, 16, 32])
# 加载权重文件
net.load_state_dict(torch.load(saved_model, map_location=lambda storage, loc: storage))

net.eval()
# 选择推理设备
net.to(net._device)

print('Loaded network {:s}'.format(saved_model))

for file in os.listdir(test_file):
    if file.startswith("._") == False:
        file_path = os.path.join(test_file, file)
        print(file_path)
        # 打开测试图片文件
        im = cv2.imread(file_path)

        # 定义计时器
        timer = Timer()
        timer.tic()
        # 检测得到图片ROI
        scores, boxes = im_detect(net, im)

        print(scores.shape, boxes.shape)
        timer.toc()
        print('Detection took {:.3f}s for {:d} object proposals'.format(timer.total_time(), boxes.shape[0]))

        # 定义阈值
        CONF_THRESH = 0.7
        NMS_THRESH = 0.3

        cls_dets = {}

        # NMS 非极大值抑制操作,过滤边界框
        for cls_ind, cls in enumerate(CLASSES[1:]):
            cls_ind += 1 # 跳过 background
            cls_boxes = boxes[:, 4*cls_ind:4*(cls_ind + 1)]
            cls_scores = scores[:, cls_ind]
            dets = np.hstack((cls_boxes,
                              cls_scores[:, np.newaxis])).astype(np.float32)
            keep = nms(torch.from_numpy(cls_boxes), torch.from_numpy(cls_scores), NMS_THRESH)
            dets = dets[keep.numpy(), :]

            if len(dets) > 0:
                if cls in cls_dets:
                    cls_dets[cls] = np.vstack([cls_dets[cls], dets]) 
                else:
                    cls_dets[cls] = dets
        vis_detections(im, cls_dets, thresh=CONF_THRESH)

        plt.show()

trying to load weights from  ./models/vgg16-voc0712/vgg16_faster_rcnn_iter_110000.pth
Loaded network ./models/vgg16-voc0712/vgg16_faster_rcnn_iter_110000.pth
./test/test_image_0.jpg
(300, 21) (300, 84)
Detection took 0.062s for 300 object proposals


./test/test_image_1.jpg
(300, 21) (300, 84)
Detection took 0.054s for 300 object proposals

【版权声明】本文为华为云社区用户原创内容,转载时必须标注文章的来源(华为云社区)、文章链接、文章作者等基本信息, 否则作者和本社区有权追究责任。如果您发现本社区中有涉嫌抄袭的内容,欢迎发送邮件进行举报,并提供相关证据,一经查实,本社区将立刻删除涉嫌侵权内容,举报邮箱: cloudbbs@huaweicloud.com
  • 点赞
  • 收藏
  • 关注作者

评论(0

0/1000
抱歉,系统识别当前为高风险访问,暂不支持该操作

全部回复

上滑加载中

设置昵称

在此一键设置昵称,即可参与社区互动!

*长度不超过10个汉字或20个英文字符,设置后3个月内不可修改。

*长度不超过10个汉字或20个英文字符,设置后3个月内不可修改。