DynamicRCNN (Object Detection / PyTorch)

HWCloudAI, posted 2022/11/30 17:06:30
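[Abstract] The three most fundamental tasks in computer vision are classification, object localization, and object detection. Classification aims to identify the class label of a given image, that is, the label (among all class labels in the training set) that the image most likely belongs to. Localization must not only recognize what the object is (its class label) but also report where it is, usually by marking it with a bounding box. Object detection is localization of multiple objects: several target objects must be located in one image, so the detection task includes both classification and localization. DynamicRC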
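The abstract notes that an object's location is usually marked with a bounding box. As a small illustration that is not part of the DynamicRCNN code, the sketch below assumes boxes given as (x1, y1, x2, y2) pixel corners and computes Intersection over Union (IoU), the usual measure of how well a predicted box overlaps a ground-truth box. The helper names `box_area` and `iou` are hypothetical.

```python
# Boxes are (x1, y1, x2, y2): top-left and bottom-right corners in pixels.

def box_area(box):
    """Area of an (x1, y1, x2, y2) box; empty or inverted boxes have area 0."""
    x1, y1, x2, y2 = box
    return max(0.0, x2 - x1) * max(0.0, y2 - y1)


def iou(box_a, box_b):
    """Intersection over Union of two (x1, y1, x2, y2) boxes."""
    # Corners of the overlapping rectangle (inverted if the boxes do not overlap).
    ix1 = max(box_a[0], box_b[0])
    iy1 = max(box_a[1], box_b[1])
    ix2 = min(box_a[2], box_b[2])
    iy2 = min(box_a[3], box_b[3])
    inter = box_area((ix1, iy1, ix2, iy2))
    union = box_area(box_a) + box_area(box_b) - inter
    return inter / union if union > 0 else 0.0


# A predicted box that partially covers the ground-truth box.
gt = (10, 10, 50, 50)      # 40 x 40 = 1600
pred = (30, 30, 70, 70)    # 40 x 40 = 1600, overlap 20 x 20 = 400
print(iou(gt, pred))       # 400 / (1600 + 1600 - 400), about 0.1429
```

An IoU of 1.0 means the boxes coincide and 0.0 means they do not overlap at all.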
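import moxing as mox
import os
# Copy the DynamicRCNN-last project from OBS (Huawei Cloud object storage) into the current directory
mox.file.copy_parallel('s3://obs-aigallery-zc/algorithm/DynamicRCNN-last','./DynamicRCNN-last')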


2. Model Training

2.1 Installing and Loading Dependencies


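import os
# Install the COCO API (pycocotools), used to read the COCO annotations
os.system('pip install pycocotools')

# Switch into the project directory so that its modules (realTrain, config, ...)
# and relative paths such as ./coco_data resolve correctly
root_path = './DynamicRCNN-last/'
os.chdir(root_path)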

from PIL import Image, ImageDraw
from tqdm import tqdm
from io import BytesIO
from collections import OrderedDict
import torch.backends.cudnn as cudnn
from torch.utils.data import DataLoader, Dataset
import torchvision.transforms as transforms
import matplotlib.pyplot as plt
import sys
import argparse
import cv2
import numpy as np
import torch

# Project-local modules shipped in DynamicRCNN-last
from realTrain import train
from config import config as cfg
from network import Network
from dynamic_rcnn.engine.checkpoint import DetectronCheckpointer
from dynamic_rcnn.engine.comm import synchronize, get_rank, get_world_size, all_gather, is_main_process
from dynamic_rcnn.utils.logger import setup_logger
from dynamic_rcnn.utils.pyt_utils import mkdir, draw_box
from dynamic_rcnn.datasets.structures.image_list import to_image_list
# Let cuDNN benchmark convolution algorithms and reuse the fastest ones (most effective when input sizes stay fixed)
cudnn.benchmark = True


2.2 Training Parameter Settings

For the full list of parameters, see DynamicRCNN-last/config.py.

parser = argparse.ArgumentParser(
        description="PyTorch Object Detection Training")

parser.add_argument("--local_rank", type=int, default=0)
parser.add_argument("--training_iter", type=int, default=100)
parser.add_argument("--num_gpus", type=int, default=1)
parser.add_argument("--eval",type=str,default='False')
parser.add_argument("--load_weight",type=str, default='std')
parser.add_argument("--check_period",type=int,default=5000)

## add or alter by BLB
parser.add_argument('--training_dataset', default='./coco_data',
                        help='Training dataset directory')  
parser.add_argument('--save_folder', default='./model',
                        help='Location to save checkpoint models')
parser.add_argument('--data_url',  default='./coco_data', type=str,
                        help='the training and validation data path')


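# parse_known_args() returns unrecognized arguments (e.g. those Jupyter passes to the kernel) in `unknown` instead of exiting
args, unknown = parser.parse_known_args()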
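Inside a Jupyter notebook, `sys.argv` typically carries the kernel's own arguments, such as `-f` followed by a connection-file path. `parse_args()` would reject these and exit, while `parse_known_args()` hands them back separately. The sketch below reproduces two of the arguments above and feeds in a hypothetical kernel command line so the behavior is visible. It also shows why `--eval`, declared with `type=str`, has to be compared as a string. The connection-file path and the `run_eval` variable are illustrative only.

```python
import argparse

# Same pattern as the cell above, reduced to two of its arguments.
parser = argparse.ArgumentParser(description="PyTorch Object Detection Training")
parser.add_argument("--training_iter", type=int, default=100)
parser.add_argument("--eval", type=str, default='False')

# Hypothetical kernel command line: "-f <connection file>" comes from Jupyter,
# "--training_iter 200" is an override we actually want applied.
argv = ['-f', '/tmp/kernel-1234.json', '--training_iter', '200']
args, unknown = parser.parse_known_args(argv)
print(args.training_iter)  # 200
print(unknown)             # ['-f', '/tmp/kernel-1234.json']

# --eval is a string, so convert it explicitly:
# bool('False') would be True because every non-empty string is truthy.
run_eval = args.eval.strip().lower() == 'true'
print(run_eval)            # False
```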

2.3 Start Training

train(args)
world size: 1

2021-05-24 16:42:46,091 train INFO: Using 1 GPUs
2021-05-24 16:42:46,094 train INFO: Namespace(check_period=5000, data_url='./coco_data', distributed=False, eval='False', load_weight='std', local_rank=0, num_gpus=1, save_folder='./model', training_dataset='./coco_data', training_iter=100)
2021-05-24 16:42:54,256 train INFO: Loading checkpoint from /home/ma-user/work/DynamicRCNN-last/trained_model/model/iteration_270000_mAP_49.2.pth
2021-05-24 16:42:57,417 train INFO: Loaded the locally pretrained best_model.pth, which reaches the accuracy reported in the paper, for transfer training
loading annotations into memory...
Done (t=0.91s)
creating index...
index created!

{1: 1, 2: 2, 3: 3, 4: 4, 5: 5, 6: 6, 7: 7, 8: 8, 9: 9, 10: 10, 11: 11, 13: 12, 14: 13, 15: 14, 16: 15, 17: 16, 18: 17, 19: 18, 20: 19, 21: 20, 22: 21, 23: 22, 24: 23, 25: 24, 27: 25, 28: 26, 31: 27, 32: 28, 33: 29, 34: 30, 35: 31, 36: 32, 37: 33, 38: 34, 39: 35, 40: 36, 41: 37, 42: 38, 43: 39, 44: 40, 46: 41, 47: 42, 48: 43, 49: 44, 50: 45, 51: 46, 52: 47, 53: 48, 54: 49, 55: 50, 56: 51, 57: 52, 58: 53, 59: 54, 60: 55, 61: 56, 62: 57, 63: 58, 64: 59, 65: 60, 67: 61, 70: 62, 72: 63, 73: 64, 74: 65, 75: 66, 76: 67, 77: 68, 78: 69, 79: 70, 80: 71, 81: 72, 82: 73, 84: 74, 85: 75, 86: 76, 87: 77, 88: 78, 89: 79, 90: 80}

2021-05-24 16:42:58,405 train INFO: Start training
2021-05-24 16:42:58,406 train INFO: max_iter:100
2021-05-24 16:44:35,066 train INFO: eta: 0:06:26  iter: 20  loss: 1.1161 (1.1389)  rpn_cls_loss: 0.0197 (0.0700)  rpn_bbox_loss: 0.0415 (0.0787)  rcnn_cls_loss: 0.6070 (0.6549)  rcnn_bbox_loss: 0.3246 (0.3352)  time: 5.6187 (4.8328)  data: 0.0111 (0.0345)  lr: 0.000200  max mem: 22685
2021-05-24 16:45:19,010 train INFO: eta: 0:03:30  iter: 40  loss: 0.9624 (1.0859)  rpn_cls_loss: 0.0257 (0.0521)  rpn_bbox_loss: 0.0129 (0.0623)  rcnn_cls_loss: 0.5040 (0.6171)  rcnn_bbox_loss: 0.3610 (0.3545)  time: 0.6644 (3.5150)  data: 0.0101 (0.0231)  lr: 0.000200  max mem: 22685
2021-05-24 16:46:25,833 train INFO: eta: 0:02:18  iter: 60  loss: 0.9568 (1.0441)  rpn_cls_loss: 0.0139 (0.0471)  rpn_bbox_loss: 0.0165 (0.0556)  rcnn_cls_loss: 0.4574 (0.5867)  rcnn_bbox_loss: 0.3618 (0.3547)  time: 0.6814 (3.4571)  data: 0.0088 (0.0191)  lr: 0.000200  max mem: 22685
2021-05-24 16:47:02,932 train INFO: eta: 0:01:01  iter: 80  loss: 0.9744 (1.0196)  rpn_cls_loss: 0.0375 (0.0494)  rpn_bbox_loss: 0.0299 (0.0577)  rcnn_cls_loss: 0.4399 (0.5507)  rcnn_bbox_loss: 0.3812 (0.3619)  time: 0.6615 (3.0565)  data: 0.0109 (0.0174)  lr: 0.000200  max mem: 22986
2021-05-24 16:47:28,373 train INFO: eta: 0:00:00  iter: 100  loss: 0.9252 (1.0421)  rpn_cls_loss: 0.0358 (0.0691)  rpn_bbox_loss: 0.0289 (0.0606)  rcnn_cls_loss: 0.4498 (0.5464)  rcnn_bbox_loss: 0.3488 (0.3661)  time: 0.6337 (2.6996)  data: 0.0111 (0.0164)  lr: 0.000200  max mem: 22986
2021-05-24 16:47:28,378 train INFO: Saving checkpoint to ./model/checkpoints/model_0000100.pth
2021-05-24 16:47:29,013 train INFO: Saving checkpoint to ./model/checkpoints/final_model.pth
2021-05-24 16:47:31,185 train INFO: Total training time: 0:04:32.776027 (2.7278 s / it)
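The numbers in this log are internally consistent, and checking that makes them easier to read. The sketch below copies three iteration lines from the log verbatim and verifies two relationships the printed values satisfy. First, the parenthesized total loss equals the sum of the four parenthesized component losses, up to rounding. Second, `eta` equals the parenthesized seconds-per-iteration multiplied by the remaining iterations. Reading the parenthesized numbers as running averages, and the first numbers as recent-window values (smoothed medians in maskrcnn-benchmark-style loggers), is an inference from these figures rather than something this post states. The first numbers do not add up the same way: at iteration 20 the four components sum to 0.9928, not 1.1161. The code does not call into DynamicRCNN, and `parse` and `check` are hypothetical helpers.

```python
import re
from datetime import timedelta

# Metric fields of three iteration lines copied from the training log above.
LOG_LINES = [
    "eta: 0:06:26  iter: 20  loss: 1.1161 (1.1389)  rpn_cls_loss: 0.0197 (0.0700)  "
    "rpn_bbox_loss: 0.0415 (0.0787)  rcnn_cls_loss: 0.6070 (0.6549)  "
    "rcnn_bbox_loss: 0.3246 (0.3352)  time: 5.6187 (4.8328)",
    "eta: 0:02:18  iter: 60  loss: 0.9568 (1.0441)  rpn_cls_loss: 0.0139 (0.0471)  "
    "rpn_bbox_loss: 0.0165 (0.0556)  rcnn_cls_loss: 0.4574 (0.5867)  "
    "rcnn_bbox_loss: 0.3618 (0.3547)  time: 0.6814 (3.4571)",
    "eta: 0:00:00  iter: 100  loss: 0.9252 (1.0421)  rpn_cls_loss: 0.0358 (0.0691)  "
    "rpn_bbox_loss: 0.0289 (0.0606)  rcnn_cls_loss: 0.4498 (0.5464)  "
    "rcnn_bbox_loss: 0.3488 (0.3661)  time: 0.6337 (2.6996)",
]
MAX_ITER = 100  # --training_iter used in this run
COMPONENTS = ["rpn_cls_loss", "rpn_bbox_loss", "rcnn_cls_loss", "rcnn_bbox_loss"]

# "name: value (average)" pairs; the eta and iter fields have no "(...)" so they do not match.
PAIR = re.compile(r"(\w+): ([\d.]+) \(([\d.]+)\)")


def parse(line):
    """Return (iteration, eta string, {metric: (recent value, running average)})."""
    iteration = int(re.search(r"iter: (\d+)", line).group(1))
    eta = re.search(r"eta: (\S+)", line).group(1)
    metrics = {name: (float(cur), float(avg)) for name, cur, avg in PAIR.findall(line)}
    return iteration, eta, metrics


def check(line):
    """Assert both relationships for one log line and return (iter, component sum, eta)."""
    iteration, eta, m = parse(line)
    # Averaged total loss == sum of averaged component losses (4-decimal rounding aside).
    component_sum = sum(m[name][1] for name in COMPONENTS)
    assert abs(component_sum - m["loss"][1]) < 5e-4, (iteration, component_sum)
    # ETA == average seconds per iteration x remaining iterations, truncated to whole seconds.
    predicted_eta = str(timedelta(seconds=int(m["time"][1] * (MAX_ITER - iteration))))
    assert predicted_eta == eta, (iteration, predicted_eta, eta)
    return iteration, round(component_sum, 4), predicted_eta


for line in LOG_LINES:
    print(check(line))

# Summary line of the log: total training time 0:04:32.776027 for 100 iterations.
total = timedelta(minutes=4, seconds=32.776027).total_seconds()
print(f"{total / MAX_ITER:.4f} s / it")  # 2.7278 s / it, as logged
```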

3. Model Testing

3.1 Prediction Classes

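These 80 classes are the object categories of the COCO dataset. The ids below are contiguous (1 to 80), whereas the original COCO category ids run from 1 to 90 with gaps. For example, 'stop sign' has COCO id 13 but contiguous id 12. The dictionary printed during training maps the original ids to the contiguous ones.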

id2name={1: 'person', 2: 'bicycle',3: 'car',4: 'motorcycle',5: 'airplane',6: 'bus',7: 'train',8: 'truck',9: 'boat',10: 'traffic light',11: 'fire hydrant',12: 'stop sign',13: 'parking meter',14: 'bench',
 15: 'bird',16: 'cat',17: 'dog',18: 'horse',19: 'sheep',20: 'cow',21: 'elephant',22: 'bear',23: 'zebra',24: 'giraffe',25: 'backpack',26: 'umbrella',27: 'handbag',28: 'tie',29: 'suitcase',30: 'frisbee',
 31: 'skis',32: 'snowboard',33: 'sports ball',34: 'kite',35: 'baseball bat',36: 'baseball glove',37: 'skateboard',38: 'surfboard',39: 'tennis racket',40: 'bottle',41: 'wine glass',42: 'cup',43: 'fork',44: 'knife',
 45: 'spoon',46: 'bowl',47: 'banana',48: 'apple',49: 'sandwich',50: 'orange',51: 'broccoli',52: 'carrot',53: 'hot dog',54: 'pizza',55: 'donut',56: 'cake',57: 'chair',58: 'couch',59: 'potted plant',
 60: 'bed',61: 'dining table',62: 'toilet',63: 'tv',64: 'laptop',65: 'mouse',66: 'remote',67: 'keyboard',68: 'cell phone',69: 'microwave',70: 'oven',71: 'toaster',72: 'sink',73: 'refrigerator',74: 'book',
 75: 'clock',76: 'vase',77: 'scissors',78: 'teddy bear',79: 'hair drier',80: 'toothbrush'}
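The dictionary printed in the training log maps original COCO category ids to the contiguous ids used by `id2name`. The sketch below assumes the contiguous id is simply the 1-based rank of each sorted COCO id. This is a common convention in maskrcnn-benchmark-style datasets, and every entry of the printed dictionary is consistent with it, but the post does not show the project's actual code. The names `json_to_contiguous`, `contiguous_to_json` and `name_for_coco_id` are hypothetical, and only a few `id2name` entries are repeated so that the block runs on its own.

```python
# Original COCO category ids used by the 80 detection classes (1 to 90, with 10 gaps),
# i.e. the keys of the dictionary printed in the training log.
COCO_CATEGORY_IDS = (
    list(range(1, 12)) + list(range(13, 26)) + [27, 28] + list(range(31, 45))
    + list(range(46, 66)) + [67, 70] + list(range(72, 83)) + list(range(84, 91))
)

# Assumed rule: contiguous id = 1-based rank of the sorted COCO id.
json_to_contiguous = {
    cat_id: rank for rank, cat_id in enumerate(sorted(COCO_CATEGORY_IDS), start=1)
}
contiguous_to_json = {v: k for k, v in json_to_contiguous.items()}

# A few entries of id2name (contiguous ids), repeated so the example is self-contained.
ID2NAME_EXCERPT = {1: 'person', 12: 'stop sign', 80: 'toothbrush'}


def name_for_coco_id(coco_id):
    """Translate an original COCO category id into a class name via its contiguous id."""
    return ID2NAME_EXCERPT[json_to_contiguous[coco_id]]


print(len(json_to_contiguous))  # 80
print(json_to_contiguous[13])   # 12
print(contiguous_to_json[80])   # 90
print(name_for_coco_id(13))     # stop sign
```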

3.3 Displaying the Prediction Results

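if __name__ == '__main__':
    img_path = 'test.jpg'  # choose the image you want to test and change this path accordingly
    model_path = './trained_model/model/iteration_270000_mAP_49.2.pth'
    # predict() is not defined in this post; it is expected to return the image with the detections drawn on it
    img = predict(img_path, model_path)
    plt.figure(figsize=(10, 10))  # set the figure size
    plt.imshow(img)
    plt.show()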
2021-05-24 17:03:46,379 test.inference INFO: Loading checkpoint from ./trained_model/model/iteration_270000_mAP_49.2.pth
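Because the `predict()` function called above is not defined in this post, its exact output is unknown. A typical final step in such a function is to drop low-confidence detections and attach a readable label to each remaining box before drawing. The sketch below assumes a hypothetical raw output format: boxes as (x1, y1, x2, y2), contiguous class ids that index `id2name`, and confidence scores. It also assumes a threshold of 0.5. `select_detections` and the sample values are illustrative only and do not come from the DynamicRCNN code.

```python
# A few id2name entries (contiguous ids), repeated so the example is self-contained.
ID2NAME_EXCERPT = {1: 'person', 3: 'car', 17: 'dog'}


def select_detections(boxes, labels, scores, id2name, score_thresh=0.5):
    """Keep detections with score >= score_thresh, highest score first, with text captions."""
    kept = []
    for box, label, score in zip(boxes, labels, scores):
        if score < score_thresh:
            continue  # discard low-confidence detections
        name = id2name.get(label, f'class_{label}')  # fall back to the raw id if unknown
        kept.append((box, f'{name} {score:.2f}', score))
    kept.sort(key=lambda det: det[2], reverse=True)
    return [(box, caption) for box, caption, _ in kept]


# Hypothetical detector output for one image.
boxes = [(12, 30, 200, 310), (220, 140, 400, 260), (50, 50, 60, 60)]
labels = [1, 17, 3]
scores = [0.97, 0.81, 0.12]

for box, caption in select_detections(boxes, labels, scores, ID2NAME_EXCERPT):
    print(box, caption)
```

Each kept `(box, caption)` pair could then be drawn with Pillow, for example `ImageDraw.Draw(img).rectangle(box, outline='red')` followed by `.text((box[0], box[1]), caption)`.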

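[Copyright notice] This article is original content by a Huawei Cloud Community user. When reposting, you must credit the source (Huawei Cloud Community), the article link, the author, and other basic information; otherwise the author and the community reserve the right to pursue liability. If you find suspected plagiarism in this community, please report it by email with supporting evidence; once verified, the community will immediately remove the infringing content. Report email: cloudbbs@huaweicloud.com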