Hands-on Object Detection: SSD (A Must-Read for Beginners)

AI浩, published 2021/12/23 01:27:12
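  1. Download the SSD code:

https://github.com/amdegroot/ssd.pytorch

  1. Unzip the downloaded code, create a VOCdevkit folder inside the data folder, and copy the dataset, organized as VOC2007, into it.

  1. Download the base network weights (vgg16_reducedfc.pth) and put them in the weights folder. Download link:

https://s3.amazonaws.com/amdegroot-models/vgg16_reducedfc.pth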
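Before editing any code, it can help to confirm that the dataset ended up where the loader looks for it. The sketch below assumes the standard VOC2007 layout (Annotations, JPEGImages and ImageSets/Main/trainval.txt under data/VOCdevkit/VOC2007); check_voc_layout is a hypothetical helper, not part of ssd.pytorch.

```python
from pathlib import Path

# Folders the VOC loader reads from, relative to <root>/VOC<year> (assumed layout).
REQUIRED_DIRS = ('Annotations', 'JPEGImages', 'ImageSets/Main')


def check_voc_layout(root='data/VOCdevkit', year='2007'):
    """Return the missing paths under <root>/VOC<year>; an empty list means OK."""
    base = Path(root) / f'VOC{year}'
    missing = [str(base / sub) for sub in REQUIRED_DIRS if not (base / sub).is_dir()]
    trainval = base / 'ImageSets' / 'Main' / 'trainval.txt'
    if not trainval.is_file():
        missing.append(str(trainval))
    return missing


if __name__ == '__main__':
    problems = check_voc_layout()
    print('VOC layout looks complete' if not problems else 'missing: ' + ', '.join(problems))
```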

  1. Modify config.py (data/config.py):

    # SSD300 CONFIGS
    voc = {
        'num_classes': 3,  # number of classes + 1 (background)
        'lr_steps': (80000, 100000, 120000),
        'max_iter': 120000,  # number of training iterations
        'feature_maps': [38, 19, 10, 5, 3, 1],
        'min_dim': 300,
        'steps': [8, 16, 32, 64, 100, 300],
        'min_sizes': [30, 60, 111, 162, 213, 264],
        'max_sizes': [60, 111, 162, 213, 264, 315],
        'aspect_ratios': [[2], [2, 3], [2, 3], [2, 3], [2], [2]],
        'variance': [0.1, 0.2],
        'clip': True,
        'name': 'VOC',
    }
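The feature_maps, steps, min_sizes, max_sizes and aspect_ratios entries together define SSD's default (prior) boxes. The following is a minimal pure-Python sketch modeled on how the repository's PriorBox layer appears to use these fields; it is written independently, so details may differ from the real implementation. Each feature-map cell gets two square boxes plus two boxes per listed aspect ratio, so this config yields 38²·4 + 19²·6 + 10²·6 + 5²·6 + 3²·4 + 1²·4 = 8732 boxes per image, the same 8732 that appears in the multibox_loss.py error message further on.

```python
from itertools import product
from math import sqrt

# Fields copied from the voc dict in config.py above.
voc = {
    'feature_maps': [38, 19, 10, 5, 3, 1],
    'min_dim': 300,
    'steps': [8, 16, 32, 64, 100, 300],
    'min_sizes': [30, 60, 111, 162, 213, 264],
    'max_sizes': [60, 111, 162, 213, 264, 315],
    'aspect_ratios': [[2], [2, 3], [2, 3], [2, 3], [2], [2]],
    'clip': True,
}


def make_priors(cfg):
    """Return default boxes as [cx, cy, w, h] lists, normalized to [0, 1]."""
    priors = []
    for k, f in enumerate(cfg['feature_maps']):
        f_k = cfg['min_dim'] / cfg['steps'][k]        # effective feature-map size
        s_k = cfg['min_sizes'][k] / cfg['min_dim']     # base box scale
        s_k_prime = sqrt(s_k * cfg['max_sizes'][k] / cfg['min_dim'])
        for i, j in product(range(f), repeat=2):
            cx, cy = (j + 0.5) / f_k, (i + 0.5) / f_k  # centre of cell (i, j)
            priors.append([cx, cy, s_k, s_k])              # square, base scale
            priors.append([cx, cy, s_k_prime, s_k_prime])  # square, larger scale
            for ar in cfg['aspect_ratios'][k]:
                r = sqrt(ar)
                priors.append([cx, cy, s_k * r, s_k / r])  # wide box (ratio ar)
                priors.append([cx, cy, s_k / r, s_k * r])  # tall box (ratio 1/ar)
    if cfg['clip']:
        priors = [[min(max(v, 0.0), 1.0) for v in p] for p in priors]
    return priors


priors = make_priors(voc)
print(len(priors))  # 8732
```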
  1. Modify voc0712.py (data/voc0712.py). Comment out the original 20 VOC classes (here by wrapping them in a string) and list your own classes instead:

    '''
    VOC_CLASSES = (  # always index 0
        'aeroplane', 'bicycle', 'bird', 'boat',
        'bottle', 'bus', 'car', 'cat', 'chair',
        'cow', 'diningtable', 'dog', 'horse',
        'motorbike', 'person', 'pottedplant',
        'sheep', 'sofa', 'train', 'tvmonitor')
    '''
    VOC_CLASSES = (  # always index 0
        'aircraft', 'oiltank')  # change to the classes of your own dataset
    # note: if you used our download scripts, this should be right
    # VOC_ROOT = osp.join("", "data/VOCdevkit/")
    VOC_ROOT = "data/VOCdevkit/"  # changed to a relative path for Windows 10
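The order of VOC_CLASSES fixes each class's label index: aircraft becomes 0 and oiltank 1, which matches the last column of the ground-truth lines in the test output later on, and len(VOC_CLASSES) + 1 gives the num_classes value of 3 set in config.py. The simplified parser below illustrates that mapping on a VOC-style XML annotation. It is a hypothetical helper, not the repository's VOCAnnotationTransform, which may differ in details (for example pixel offsets or the handling of "difficult" objects), and the sample annotation's image size is invented for illustration.

```python
import xml.etree.ElementTree as ET

VOC_CLASSES = ('aircraft', 'oiltank')
NUM_CLASSES = len(VOC_CLASSES) + 1  # +1 for background, i.e. 'num_classes': 3
CLASS_TO_IND = {name: i for i, name in enumerate(VOC_CLASSES)}


def parse_voc_xml(xml_text):
    """Return [[xmin, ymin, xmax, ymax, label_idx], ...] with coords scaled to [0, 1]."""
    root = ET.fromstring(xml_text)
    width = float(root.find('size/width').text)
    height = float(root.find('size/height').text)
    boxes = []
    for obj in root.iter('object'):
        name = obj.find('name').text.strip().lower()
        bb = obj.find('bndbox')
        xmin, ymin, xmax, ymax = (float(bb.find(t).text) for t in ('xmin', 'ymin', 'xmax', 'ymax'))
        boxes.append([xmin / width, ymin / height, xmax / width, ymax / height, CLASS_TO_IND[name]])
    return boxes


# Hypothetical annotation (image size invented; box taken from the aircraft_27 output).
SAMPLE_XML = """<annotation>
  <size><width>1000</width><height>800</height><depth>3</depth></size>
  <object>
    <name>aircraft</name>
    <bndbox><xmin>650</xmin><ymin>101</ymin><xmax>753</xmax><ymax>227</ymax></bndbox>
  </object>
</annotation>"""

print(parse_voc_xml(SAMPLE_XML))
```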

 

  
Then, in the same file, change the default arguments of the VOCDetection constructor from:

    def __init__(self, root,
                 image_sets=[('2007', 'trainval'), ('2012', 'trainval')],
                 transform=None, target_transform=VOCAnnotationTransform(),
                 dataset_name='VOC0712'):

to the following, since only VOC2007-style data is used:

    def __init__(self, root,
                 image_sets=[('2007', 'trainval')],
                 transform=None, target_transform=VOCAnnotationTransform(),
                 dataset_name='VOC2007'):
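Why the image_sets default matters: as I read the repository's code, VOCDetection opens one ImageSets/Main/<name>.txt file per (year, name) pair under root/VOC<year>, so leaving ('2012', 'trainval') in the default fails when only VOC2007 exists. A minimal sketch of that enumeration, using a hypothetical build_ids helper:

```python
import os.path as osp


def build_ids(root, image_sets=(('2007', 'trainval'),)):
    """List (year_dir, image_id) pairs, one per line of each image-set file."""
    ids = []
    for year, name in image_sets:
        rootpath = osp.join(root, 'VOC' + year)
        # Raises FileNotFoundError if e.g. VOC2012 is not present.
        with open(osp.join(rootpath, 'ImageSets', 'Main', name + '.txt')) as f:
            ids.extend((rootpath, line.strip()) for line in f if line.strip())
    return ids


print(build_ids.__doc__)
```

Usage: build_ids('data/VOCdevkit') lists every image ID in VOC2007/ImageSets/Main/trainval.txt.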
 
 
  1. Modify coco.py:

    COCO_ROOT = 'data/'

  1. Modify ssd.py

Change line 32,

    self.cfg = (coco, voc)[num_classes == 21]

to:

    self.cfg = (coco, voc)[num_classes == 3]  # 3 = number of classes in my dataset + 1
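The expression (coco, voc)[num_classes == 3] works because a Python bool indexes a tuple as 0 (False) or 1 (True), so the VOC config is chosen only when the class count matches. The sketch uses placeholder dicts with illustrative values; pick_cfg_robust is a hypothetical alternative that reads the count from the voc dict, so it only has to be changed in config.py.

```python
# Placeholder configs with only the fields used here (values are illustrative).
coco = {'name': 'COCO'}
voc = {'name': 'VOC', 'num_classes': 3}


def pick_cfg(num_classes):
    # False -> index 0 (coco), True -> index 1 (voc), as in the modified ssd.py line
    return (coco, voc)[num_classes == 3]


def pick_cfg_robust(num_classes):
    # Hypothetical variant: compare with the count stored in the voc config,
    # so changing the dataset only requires editing config.py.
    return voc if num_classes == voc['num_classes'] else coco


print(pick_cfg(3)['name'], pick_cfg(21)['name'])
```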
 
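  1. Modify multibox_loss.py

At line 97, directly above

    loss_c[pos] = 0  # filter out pos boxes for now

add:

    loss_c = loss_c.view(pos.size()[0], pos.size()[1])

Without this line, the mask and the loss tensor have mismatched shapes (loss_c has one row per prior across the whole batch, 4 x 8732 = 34928, while pos is [batch, num_priors]), for example:

    IndexError: The shape of the mask [4, 8732] at index 0 does not match the shape of the indexed tensor [34928, 1] at index 0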
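To see why the reshape has to come first, the NumPy sketch below reproduces the same kind of mismatch on toy numbers and then performs hard negative mining the way multibox_loss.py appears to: positives are zeroed, the remaining priors are ranked by loss within each image, and at most negpos_ratio negatives per positive are kept. It is an illustration, not the repository's PyTorch code; hard_negative_mask is a hypothetical name and negpos_ratio=3 is an assumed default.

```python
import numpy as np


def hard_negative_mask(loss_c_flat, pos, negpos_ratio=3):
    """Boolean [batch, num_priors] mask of the negatives kept for the loss.

    loss_c_flat: [batch * num_priors, 1] confidence loss, one row per prior
    pos:         [batch, num_priors] boolean mask of positive (matched) priors
    """
    batch, num_priors = pos.shape
    # The fix: reshape to [batch, num_priors] before indexing with pos.
    loss_c = loss_c_flat.reshape(batch, num_priors).copy()
    loss_c[pos] = 0  # positives must never be picked as negatives
    # Rank priors within each image by loss, 0 = largest loss.
    loss_idx = np.argsort(-loss_c, axis=1, kind='stable')
    idx_rank = np.argsort(loss_idx, axis=1, kind='stable')
    num_pos = pos.sum(axis=1, keepdims=True)
    num_neg = np.minimum(negpos_ratio * num_pos, num_priors - 1)
    return idx_rank < num_neg


# Toy batch: 2 images x 5 priors; only prior 0 of image 0 is positive.
loss_flat = np.array([0.9, 0.1, 0.5, 0.3, 0.8, 0.2, 0.7, 0.4, 0.6, 0.05]).reshape(-1, 1)
pos = np.zeros((2, 5), dtype=bool)
pos[0, 0] = True

try:
    loss_flat[pos] = 0  # same kind of shape mismatch as the error above
except IndexError as err:
    print('IndexError:', err)

print(hard_negative_mask(loss_flat, pos))
```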
 
  1. Modify the global configuration parameters in train.py:

    parser.add_argument('--dataset', default='VOC', choices=['VOC', 'COCO'],
                        type=str, help='VOC or COCO')  # set to VOC
    parser.add_argument('--batch_size', default=4, type=int,
                        help='Batch size for training')  # pick a value that fits in your GPU memory
    parser.add_argument('--start_iter', default=0, type=int,
                        help='Resume training at this iter')  # iteration to start from
    parser.add_argument('--num_workers', default=4, type=int,
                        help='Number of workers used in dataloading')
    parser.add_argument('--cuda', default=True, type=str2bool,
                        help='Use CUDA to train model')
    parser.add_argument('--lr', '--learning-rate', default=1e-4, type=float,
                        help='initial learning rate')
    parser.add_argument('--momentum', default=0.9, type=float,
                        help='Momentum value for optim')
    parser.add_argument('--weight_decay', default=5e-4, type=float,
                        help='Weight decay for SGD')
    parser.add_argument('--gamma', default=0.1, type=float,
                        help='Gamma update for SGD')
    parser.add_argument('--visdom', default=False, type=str2bool,
                        help='Use visdom for loss visualization')
    parser.add_argument('--save_folder', default='weights/',
                        help='Directory for saving checkpoint models')  # where checkpoints are saved
    args = parser.parse_args()
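The --lr and --gamma defaults work together with the lr_steps entry in config.py as a step schedule. Assuming train.py multiplies the initial rate by gamma each time an iteration listed in lr_steps is reached (as its adjust_learning_rate helper appears to do), the rate in effect at any iteration can be sketched as follows; lr_at is a hypothetical helper.

```python
LR_STEPS = (80000, 100000, 120000)  # 'lr_steps' in config.py
BASE_LR = 1e-4                      # --lr default above
GAMMA = 0.1                         # --gamma default above


def lr_at(iteration, base_lr=BASE_LR, gamma=GAMMA, lr_steps=LR_STEPS):
    """Learning rate in effect at `iteration` under a step schedule."""
    step_index = sum(1 for step in lr_steps if iteration >= step)
    return base_lr * gamma ** step_index


for it in (0, 79999, 80000, 100000):
    print(it, lr_at(it))
```

With these defaults the rate drops tenfold at 80000 and again at 100000 iterations, and training ends at max_iter = 120000.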

Next, change the lines that read the loss values with .data[0] (for example loss_l.data[0]) so that they use .item() instead:

    loc_loss += loss_l.item()
    conf_loss += loss_c.item()
    if iteration % 10 == 0:
        print('timer: %.4f sec.' % (t1 - t0))
        print('iter ' + repr(iteration) + ' || Loss: %.4f ||' % (loss.item()), end=' ')

Without this change, recent PyTorch versions raise:

    IndexError: invalid index of a 0-dim tensor. Use tensor.item() to convert a 0-dim tensor to a Python number
       
     
       
Change line 165,

    images, targets = next(batch_iterator)

to:

    try:
        images, targets = next(batch_iterator)
    except StopIteration:
        # the DataLoader is exhausted after one epoch; start a new pass
        batch_iterator = iter(data_loader)
        images, targets = next(batch_iterator)

Otherwise training stops with a StopIteration once the DataLoader has been exhausted:

    Traceback (most recent call last):
      File "E:/ssd.pytorch-master/ssd.pytorch-master/train.py", line 246, in <module>
        train()
      File "E:/ssd.pytorch-master/ssd.pytorch-master/train.py", line 159, in train
        images, targets = next(batch_iterator)
      File "D:\Users\WH\anaconda3\lib\site-packages\torch\utils\data\dataloader.py", line 345, in __next__
        data = self._next_data()
      File "D:\Users\WH\anaconda3\lib\site-packages\torch\utils\data\dataloader.py", line 831, in _next_data
        raise StopIteration
    StopIteration
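The try/except restarts the DataLoader once an epoch is used up. An equivalent pattern is a small generator that loops over the loader forever; infinite_batches is a hypothetical helper, and any re-iterable object (a DataLoader, or a plain list in the test) works with it.

```python
def infinite_batches(data_loader):
    """Yield batches forever, starting a new pass whenever one is exhausted."""
    while True:
        got_any = False
        for batch in data_loader:
            got_any = True
            yield batch
        if not got_any:
            # Guard against looping forever on an empty dataset.
            raise ValueError('data_loader produced no batches')


batches = infinite_batches([('img0', 'tgt0'), ('img1', 'tgt1')])
print([next(batches) for _ in range(3)])
```

In train.py this would mean creating batch_iterator = infinite_batches(data_loader) once; images, targets = next(batch_iterator) then never raises StopIteration.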
       
     
       
     
       
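If the loss becomes NaN, lower the initial learning rate:

https://img-blog.csdnimg.cn/20190324161920234.png

That is, change the --lr default from 1e-3 (10^-3) to 1e-4 (10^-4), the value already used in the train.py parameters above.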
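To notice a diverging run early instead of training on NaN for thousands of iterations, one option is to check the loss value at every iteration. A minimal sketch; check_loss is a hypothetical helper that you would call on loss.item() inside the training loop.

```python
import math


def check_loss(loss_value, iteration):
    """Return loss_value unchanged, or stop training if it is NaN or infinite."""
    if not math.isfinite(loss_value):
        raise FloatingPointError(
            f'loss is {loss_value} at iteration {iteration}; '
            'try restarting with a lower initial learning rate (--lr)')
    return loss_value


# e.g. in the training loop: check_loss(loss.item(), iteration)
print(check_loss(2.31, 0))
```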
       
Once these problems are fixed, training can begin.
       
  2. Testing

Modify the global parameters in test.py:

    parser.add_argument('--trained_model', default='weights/ssd300_COCO_60000.pth',
                        type=str, help='Trained state_dict file path to open')  # model used for testing
    parser.add_argument('--save_folder', default='eval/', type=str,
                        help='Dir to save results')  # where the test results are written
    parser.add_argument('--visual_threshold', default=0.6, type=float,
                        help='Final confidence threshold')  # minimum confidence for a detection

Run test.py to generate the test results. Each ground-truth line gives xmin || ymin || xmax || ymax || class index, and each prediction gives the class, its score, and xmin || ymin || xmax || ymax:

    GROUND TRUTH FOR: aircraft_27
    label: 650.0 || 101.0 || 753.0 || 227.0 || 0
    label: 395.0 || 351.0 || 521.0 || 464.0 || 0
    label: 320.0 || 479.0 || 465.0 || 606.0 || 0
    label: 276.0 || 617.0 || 432.0 || 753.0 || 0
    PREDICTIONS:
    1 label: aircraft score: tensor(1.0000) 317.00345 || 479.0164 || 468.34143 || 604.06415
    2 label: aircraft score: tensor(1.0000) 279.88608 || 614.7385 || 434.32086 || 752.4983
    3 label: aircraft score: tensor(0.9999) 397.0827 || 352.3487 || 524.27655 || 462.45944
    4 label: aircraft score: tensor(0.9990) 647.89233 || 100.57587 || 758.556 || 232.56816
    GROUND TRUTH FOR: oiltank_349
    label: 115.0 || 340.0 || 237.0 || 466.0 || 1
    label: 104.0 || 498.0 || 231.0 || 637.0 || 1
    label: 92.0 || 675.0 || 224.0 || 814.0 || 1
    label: 369.0 || 378.0 || 492.0 || 507.0 || 1
    label: 357.0 || 543.0 || 483.0 || 678.0 || 1
    label: 350.0 || 715.0 || 485.0 || 856.0 || 1
    label: 534.0 || 378.0 || 662.0 || 511.0 || 1
    label: 527.0 || 541.0 || 656.0 || 678.0 || 1
    label: 527.0 || 712.0 || 654.0 || 857.0 || 1
    PREDICTIONS:
    1 label: oiltank score: tensor(1.0000) 349.36838 || 709.4552 || 481.7989 || 854.8713
    2 label: oiltank score: tensor(1.0000) 528.97327 || 546.3467 || 656.2798 || 676.9037
    3 label: oiltank score: tensor(1.0000) 88.82061 || 667.3125 || 220.63167 || 812.9001
    4 label: oiltank score: tensor(1.0000) 358.38913 || 373.05368 || 489.21268 || 510.5563
    5 label: oiltank score: tensor(1.0000) 519.9066 || 709.3708 || 658.5541 || 863.88947
    6 label: oiltank score: tensor(1.0000) 115.77756 || 339.24872 || 240.00787 || 466.19022
    7 label: oiltank score: tensor(1.0000) 104.77564 || 500.64545 || 230.20764 || 636.5293
    8 label: oiltank score: tensor(1.0000) 357.70694 || 547.2647 || 485.0283 || 679.9714
    9 label: oiltank score: tensor(0.9999) 524.7541 || 375.12167 || 668.46106 || 511.2381
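A quick way to judge how close the predictions are to the ground truth is intersection over union (IoU). The sketch below computes, for each aircraft_27 prediction listed above, the IoU with its best-matching ground-truth box; the coordinates are copied from the output, and iou is a hypothetical helper name.

```python
def iou(a, b):
    """Intersection over union of two [xmin, ymin, xmax, ymax] boxes."""
    iw = max(0.0, min(a[2], b[2]) - max(a[0], b[0]))
    ih = max(0.0, min(a[3], b[3]) - max(a[1], b[1]))
    inter = iw * ih
    area_a = (a[2] - a[0]) * (a[3] - a[1])
    area_b = (b[2] - b[0]) * (b[3] - b[1])
    union = area_a + area_b - inter
    return inter / union if union > 0 else 0.0


# aircraft_27 boxes copied from the test output above (xmin, ymin, xmax, ymax)
ground_truth = [
    [650.0, 101.0, 753.0, 227.0],
    [395.0, 351.0, 521.0, 464.0],
    [320.0, 479.0, 465.0, 606.0],
    [276.0, 617.0, 432.0, 753.0],
]
predictions = [
    [317.00345, 479.0164, 468.34143, 604.06415],
    [279.88608, 614.7385, 434.32086, 752.4983],
    [397.0827, 352.3487, 524.27655, 462.45944],
    [647.89233, 100.57587, 758.556, 232.56816],
]

best_ious = [max(iou(p, g) for g in ground_truth) for p in predictions]
for n, value in enumerate(best_ious, start=1):
    print(f'prediction {n}: best IoU = {value:.3f}')
```

The PASCAL VOC evaluation counts a detection as correct when its IoU with a ground-truth box is at least 0.5.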
  1. Display the detection results

    import os
    import sys

    # The demo runs from a sub-folder, so put the repository root on the path.
    module_path = os.path.abspath(os.path.join('..'))
    if module_path not in sys.path:
        sys.path.append(module_path)

    import cv2
    import numpy as np
    import torch
    from matplotlib import pyplot as plt

    if torch.cuda.is_available():
        torch.set_default_tensor_type('torch.cuda.FloatTensor')

    from ssd import build_ssd
    from data import VOC_CLASSES as labels

    net = build_ssd('test', 300, 3)  # initialize SSD: test phase, 300x300 input, 3 classes (2 + background)
    net.load_state_dict(torch.load('weights/ssd300_COCO_60000.pth', map_location='cpu'))  # load the trained weights
    net.eval()

    image = cv2.imread('data/VOCdevkit/aircraft_27.jpg', cv2.IMREAD_COLOR)  # load the image (BGR)
    rgb_image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)  # RGB copy for display

    # Preprocess: resize to 300x300, subtract the per-channel means (BGR order), then BGR -> RGB
    x = cv2.resize(image, (300, 300)).astype(np.float32)
    x -= (104.0, 117.0, 123.0)
    x = x[:, :, ::-1].copy()
    x = torch.from_numpy(x).permute(2, 0, 1)  # HWC -> CHW
    xx = x.unsqueeze(0)  # add the batch dimension
    if torch.cuda.is_available():
        xx = xx.cuda()
    with torch.no_grad():
        y = net(xx)

    top_k = 10  # keep the top 10 results (not used by the loop below)

    plt.figure(figsize=(10, 10))
    colors = plt.cm.hsv(np.linspace(0, 1, 21)).tolist()
    plt.imshow(rgb_image)  # plot the image for matplotlib
    currentAxis = plt.gca()

    detections = y.data  # [batch, num_classes, top_k, 5]; each row is (score, x0, y0, x1, y1) in [0, 1]
    # scale each detection back up to the image: (W, H, W, H)
    scale = torch.Tensor(rgb_image.shape[1::-1]).repeat(2)
    for i in range(1, detections.size(1)):  # skip class 0 (background)
        j = 0
        while j < detections.size(2) and detections[0, i, j, 0] >= 0.35:
            score = detections[0, i, j, 0].item()
            label_name = labels[i - 1]
            display_txt = '%s: %.2f' % (label_name, score)
            pt = (detections[0, i, j, 1:] * scale).cpu().numpy()
            coords = (pt[0], pt[1]), pt[2] - pt[0] + 1, pt[3] - pt[1] + 1
            color = colors[i]
            currentAxis.add_patch(plt.Rectangle(*coords, fill=False, edgecolor=color, linewidth=2))
            currentAxis.text(pt[0], pt[1], display_txt, bbox={'facecolor': color, 'alpha': 0.5})
            j += 1
    plt.show()
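Two details in the plotting loop are easy to miss: rgb_image.shape[1::-1] turns an (H, W, C) shape into (W, H), so repeat(2) yields the (W, H, W, H) factors that map normalized corners back to pixels, and plt.Rectangle expects a top-left corner plus width and height rather than two corners. A pure-Python sketch of the same conversion, with to_rect as a hypothetical helper:

```python
def to_rect(norm_box, image_shape):
    """Turn a normalized [x0, y0, x1, y1] box into plt.Rectangle arguments.

    image_shape is a NumPy-style (H, W, C) shape, e.g. rgb_image.shape.
    """
    # (H, W, C)[1::-1] -> (W, H): the same slice the scale line uses.
    w_h = tuple(image_shape[1::-1])
    scale = w_h * 2  # (W, H, W, H), like torch.Tensor(...).repeat(2)
    x0, y0, x1, y1 = (v * s for v, s in zip(norm_box, scale))
    # Top-left corner plus width/height; the +1 follows the original code's
    # inclusive-pixel convention.
    return (x0, y0), x1 - x0 + 1, y1 - y0 + 1


print(to_rect([0.1, 0.2, 0.5, 0.6], (100, 200, 3)))
```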

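Source: wanghao.blog.csdn.net. Author: AI浩. Copyright belongs to the original author; please contact the author before reposting.

Original link: wanghao.blog.csdn.net/article/details/105800157

[Copyright notice] This article was reposted by a Huawei Cloud community user. If you find content in this community that you suspect is plagiarized, please report it by email with supporting evidence; once verified, the community will promptly remove the infringing content. Report email: cloudbbs@huaweicloud.com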