Fast-SCNN (Semantic Segmentation / PyTorch)
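Semantic segmentation, also called full-pixel semantic segmentation, is a classic computer vision problem in which every pixel of an image is assigned the class ID of the object of interest it belongs to. Earlier computer vision techniques only detected low-level elements such as edges (lines and curves) or gradients, and never gave pixel-level image understanding the way human perception does. Semantic segmentation closes this gap by grouping together the parts of an image that belong to the same object, which has widened its range of applications. Fast-SCNN is a real-time semantic segmentation model. It builds on the existing two-branch approach (BiSeNet), introduces a learning-to-downsample module, and reaches 68.0% mIoU on Cityscapes.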
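In Fast-SCNN the first layers are shared by the two branches. According to the Fast-SCNN paper, this learning-to-downsample module is a standard convolution followed by two depthwise separable convolutions (DSConv), each with stride 2 and with 32, 48 and 64 output channels, so the shared features sit at 1/8 of the input resolution. The sketch below only traces tensor shapes through those three layers to make the 1/8 figure concrete; the helper names are illustrative, and 3×3 kernels with padding 1 are an assumption of this sketch.

```python
def conv_out(size, kernel=3, stride=2, padding=1):
    """Output length of a 2D convolution along one spatial axis."""
    return (size + 2 * padding - kernel) // stride + 1

# (layer type, output channels, stride) following the paper's description;
# 3x3 kernels with padding 1 are an assumption of this sketch.
LEARNING_TO_DOWNSAMPLE = [("conv", 32, 2), ("dsconv", 48, 2), ("dsconv", 64, 2)]

def downsample_shapes(height, width, in_channels=3):
    """Return (channels, height, width) after each layer, input first."""
    shapes = [(in_channels, height, width)]
    for _layer, out_channels, stride in LEARNING_TO_DOWNSAMPLE:
        height = conv_out(height, stride=stride)
        width = conv_out(width, stride=stride)
        shapes.append((out_channels, height, width))
    return shapes

shapes = downsample_shapes(1024, 2048)   # full-resolution Cityscapes frame
for shape in shapes:
    print(shape)
# (3, 1024, 2048) -> (32, 512, 1024) -> (48, 256, 512) -> (64, 128, 256)
```

Because both branches start from this 1/8-resolution feature map, the early layers are computed only once instead of once per branch.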
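This case is a hands-on reproduction of the Fast-SCNN paper. The model implements the architecture proposed in "Fast-SCNN: Fast Semantic Segmentation Network", loads pre-trained weights, and is trained on the Cityscapes dataset.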
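Cityscapes label files in the standard gtFine format (`*_gtFine_labelIds.png`) store the full set of label IDs, while training and evaluation use only 19 classes. Assuming the subset used here keeps that standard format, a minimal sketch of the labelId-to-trainId conversion with a lookup table is shown below. The repository's dataset loader performs its own conversion and may use a different ignore value (set in its config); 255 follows the official cityscapesScripts convention, and `to_train_ids` is a hypothetical helper.

```python
import numpy as np

# labelId -> trainId for the 19 evaluated Cityscapes classes, as defined in the
# official cityscapesScripts labels.py (road=7->0, ..., bicycle=33->18).
LABEL_TO_TRAIN = {
    7: 0, 8: 1, 11: 2, 12: 3, 13: 4, 17: 5, 19: 6, 20: 7, 21: 8, 22: 9,
    23: 10, 24: 11, 25: 12, 26: 13, 27: 14, 28: 15, 31: 16, 32: 17, 33: 18,
}

def to_train_ids(label_ids, ignore_value=255):
    """Map a uint8 labelIds image to trainIds; unmapped IDs become ignore_value."""
    lut = np.full(256, ignore_value, dtype=np.int64)    # every ID ignored by default
    for label_id, train_id in LABEL_TO_TRAIN.items():
        lut[label_id] = train_id
    return lut[label_ids]                                # vectorised per-pixel lookup

label_ids = np.array([[7, 26], [0, 33]], dtype=np.uint8)   # road, car, unlabeled, bicycle
train_ids = to_train_ids(label_ids)
print(train_ids.tolist())   # [[0, 13], [255, 18]]
```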
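Notes:
- Framework used in this case: PyTorch 1.4.0
- How to run the code: click the triangular Run button in the menu bar at the top of this page, or press Ctrl+Enter, to run the code in each cell
- Detailed JupyterLab usage: see "ModelArts JupyterLab User Guide"
- Troubleshooting: see "ModelArts JupyterLab FAQs"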
1. Download Data and Code
Run the code below to download and unzip the data and code.
This case uses a subset of Cityscapes; the data is located in fast-scnn/datasets.
import os
# Download the data and code
!wget https://obs-aigallery-zc.obs.cn-north-4.myhuaweicloud.com/algorithm/fast-scnn.zip
# Unzip into the current directory
os.system('unzip fast-scnn.zip -d ./')
--2021-06-16 15:28:21-- https://obs-aigallery-zc.obs.cn-north-4.myhuaweicloud.com/algorithm/fast-scnn.zip
Resolving proxy-notebook.modelarts.com (proxy-notebook.modelarts.com)... 192.168.6.62
Connecting to proxy-notebook.modelarts.com (proxy-notebook.modelarts.com)|192.168.6.62|:8083... connected.
Proxy request sent, awaiting response... 200 OK
Length: 2215542147 (2.1G) [application/zip]
Saving to: ‘fast-scnn.zip’
fast-scnn.zip 100%[===================>] 2.06G 356MB/s in 6.2s
2021-06-16 15:28:27 (343 MB/s) - ‘fast-scnn.zip’ saved [2215542147/2215542147]
0
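The cell above shells out to `unzip`, and `os.system` returns the command's exit status (the `0` printed above means success). As an alternative sketch, not what the original notebook does, the standard-library `zipfile` module can verify and extract the archive from Python and raise an error on failure instead of returning a status code. `extract_zip` is a hypothetical helper; note that `testzip()` reads every member, so it takes a while on a 2 GB archive.

```python
import os
import tempfile
import zipfile

def extract_zip(archive_path, dest_dir="./"):
    """Check CRCs, extract the archive, and return its top-level entries."""
    with zipfile.ZipFile(archive_path) as zf:
        bad = zf.testzip()                 # name of the first corrupt member, or None
        if bad is not None:
            raise RuntimeError(f"corrupt member in {archive_path}: {bad}")
        zf.extractall(dest_dir)            # creates dest_dir if needed
        return sorted({name.split("/")[0] for name in zf.namelist()})

# Self-contained demo on a tiny archive; for the real download one would call
# extract_zip("fast-scnn.zip", "./").
with tempfile.TemporaryDirectory() as tmp:
    demo_zip = os.path.join(tmp, "demo.zip")
    with zipfile.ZipFile(demo_zip, "w") as zf:
        zf.writestr("fast-scnn/README.txt", "placeholder")
    top_level = extract_zip(demo_zip, os.path.join(tmp, "out"))
    extracted_ok = os.path.isfile(os.path.join(tmp, "out", "fast-scnn", "README.txt"))
print(top_level, extracted_ok)   # ['fast-scnn'] True
```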
2. Model Training
2.1 Install and Import Dependencies
import time
import copy
import datetime
import sys
import os
os.system('pip install thop')
os.system('pip install tabulate')
os.system('pip install -U PyYAML')
root_path = './fast-scnn/'
os.chdir(root_path)
import logging
import torch
import torch.nn as nn
import torch.utils.data as data
import torch.nn.functional as F
from torchvision import transforms
from tools.train import *
from segmentron.data.dataloader import get_segmentation_dataset
from segmentron.models.model_zoo import get_segmentation_model
from segmentron.solver.loss import get_segmentation_loss
from segmentron.solver.optimizer import get_optimizer
from segmentron.solver.lr_scheduler import get_scheduler
from segmentron.utils.distributed import *
from segmentron.utils.score import SegmentationMetric
from segmentron.utils.filesystem import save_checkpoint
from segmentron.utils.options import parse_args
from segmentron.utils.default_setup import default_setup
from segmentron.utils.visualize import show_flops_params
from segmentron.utils.visualize import get_color_pallete
from segmentron.config import cfg
import argparse
import matplotlib.pyplot as plt
from PIL import Image
/home/ma-user/anaconda3/envs/Pytorch-1.4.0/lib/python3.6/site-packages/requests/__init__.py:80: RequestsDependencyWarning: urllib3 (1.26.3) or chardet (3.0.4) doesn't match a supported version!
RequestsDependencyWarning)
INFO:root:Using MoXing-v2.0.0.rc0-19e4d3ab
INFO:root:Using OBS-Python-SDK-3.20.9.1
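The cell above runs `pip install` unconditionally and does not check whether it succeeded. A small optional sketch, not part of the original case, that reports which of the three dependencies are actually importable; note that the PyYAML package is imported as `yaml`, so the pip name and the module name differ. `missing_packages` is a hypothetical helper.

```python
import importlib.util

# pip package name -> module name used in `import` (they differ for PyYAML).
REQUIRED = {"thop": "thop", "tabulate": "tabulate", "PyYAML": "yaml"}

def missing_packages(required):
    """Return the pip names whose module cannot be found on the import path."""
    return [pip_name for pip_name, module in required.items()
            if importlib.util.find_spec(module) is None]

missing = missing_packages(REQUIRED)
print("missing:", missing or "none")
# The notebook could then install only these, e.g. os.system("pip install " + " ".join(missing))
```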
2.2 Training Parameters
For the full set of parameters, see fast-scnn/configs/cityscapes_fast_scnn.yaml and fast-scnn/segmentron/config/settings.py.
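In this kind of setup, settings.py typically provides default values and the YAML file overrides the subset of keys it lists; the code below loads the YAML with `cfg.update_from_file(args.config_file)`. A hypothetical sketch of that layering with plain dictionaries; it is not segmentron's implementation, and every key and value below is invented for illustration.

```python
import copy

def merge_config(defaults, overrides, prefix=""):
    """Return a copy of `defaults` with `overrides` applied recursively.

    Unknown keys raise KeyError, so a typo in the YAML file fails loudly.
    """
    merged = copy.deepcopy(defaults)
    for key, value in overrides.items():
        full_key = prefix + key
        if key not in merged:
            raise KeyError(f"unknown config key: {full_key}")
        if isinstance(merged[key], dict) and isinstance(value, dict):
            merged[key] = merge_config(merged[key], value, full_key + ".")
        else:
            merged[key] = value
    return merged

# Hypothetical keys and values, standing in for settings.py defaults and YAML overrides.
defaults = {"TRAIN": {"EPOCHS": 100, "BATCH_SIZE": 8}, "MODEL": {"MODEL_NAME": ""}}
overrides = {"TRAIN": {"EPOCHS": 160}, "MODEL": {"MODEL_NAME": "FastSCNN"}}
merged_cfg = merge_config(defaults, overrides)
print(merged_cfg)
# {'TRAIN': {'EPOCHS': 160, 'BATCH_SIZE': 8}, 'MODEL': {'MODEL_NAME': 'FastSCNN'}}
```

Rejecting unknown keys is a design choice: a misspelled option in the YAML file then raises an error instead of being silently ignored.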
parser = argparse.ArgumentParser(description='Run')
parser.add_argument('--num_nodes', type=int, default=1)
parser.add_argument('--cuda_visiable', type=str, default='0')
parser.add_argument('--config_file', default='./configs/cityscapes_fast_scnn.yaml', help='config file path')
parser.add_argument('--training_dataset', default='/home/ma-user/work/fast-scnn/datasets/', help='Training dataset directory')
# CUDA settings
parser.add_argument('--no-cuda', action='store_true', default=False,
                    help='disables CUDA training')
parser.add_argument('--local_rank', type=int, default=0)
# pre-trained weights
parser.add_argument('--resume', type=str, default='./pre-trained_weights/best_model.pth',
                    help='put the path to resuming file if needed')
parser.add_argument('--log-iter', type=int, default=10,
                    help='print log every log-iter')
# for evaluation
parser.add_argument('--val-epoch', type=int, default=1,
                    help='run validation every val-epoch')
parser.add_argument('--skip-val', action='store_true', default=False,
                    help='skip validation during training')
args, unknown = parser.parse_known_args()
# get config
cfg.update_from_file(args.config_file)
cfg.PHASE = 'train'
cfg.ROOT_PATH = root_path
cfg.check_and_freeze()
# Set up the training environment (logger, random seed, etc.)
default_setup(args)
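The cell above uses `parse_known_args()` rather than `parse_args()`. Inside a notebook, `sys.argv` holds the kernel's own arguments (typically `-f` followed by a connection-file path), which `parse_args()` would reject with an error; `parse_known_args()` returns them separately in `unknown`. argparse also converts dashes to underscores, so `--log-iter` is read as `args.log_iter`. A self-contained demonstration with a simulated argument list (the kernel file path is made up):

```python
import argparse

parser = argparse.ArgumentParser(description="demo")
parser.add_argument("--log-iter", type=int, default=10)
parser.add_argument("--skip-val", action="store_true", default=False)

# Simulated notebook argv: one known option plus the kernel's "-f <file>" pair.
argv = ["--log-iter", "5", "-f", "/tmp/kernel-demo.json"]
args, unknown = parser.parse_known_args(argv)
print(args.log_iter, args.skip_val, unknown)  # 5 False ['-f', '/tmp/kernel-demo.json']
```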
2.3 Start Training
trainer = Trainer(args)
trainer.train()
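`trainer.train()` runs the whole training loop; according to the argument help text, unless `--skip-val` is set it validates every `--val-epoch` epochs, using the `SegmentationMetric` class imported earlier. The headline Cityscapes metric, including the 68.0% quoted for Fast-SCNN, is mean intersection-over-union (mIoU). A minimal NumPy sketch of computing mIoU from a confusion matrix; it is not the repository's `SegmentationMetric` code, and the helper names and the ignore value 255 are assumptions.

```python
import numpy as np

def confusion_matrix(pred, target, num_classes, ignore_index=255):
    """num_classes x num_classes counts; rows = ground truth, columns = prediction."""
    valid = target != ignore_index                       # skip unlabelled pixels
    idx = num_classes * target[valid].astype(np.int64) + pred[valid].astype(np.int64)
    return np.bincount(idx, minlength=num_classes ** 2).reshape(num_classes, num_classes)

def mean_iou(conf):
    """Per-class IoU = TP / (TP + FP + FN) and their mean over classes that occur."""
    tp = np.diag(conf).astype(np.float64)
    fp = conf.sum(axis=0) - tp                           # predicted c, labelled otherwise
    fn = conf.sum(axis=1) - tp                           # labelled c, predicted otherwise
    denom = tp + fp + fn
    iou = np.divide(tp, denom, out=np.full_like(tp, np.nan), where=denom > 0)
    return iou, np.nanmean(iou)

# Toy 2x3 label map with 3 classes; 255 marks an ignored pixel.
target = np.array([[0, 0, 1], [2, 255, 1]])
pred = np.array([[0, 1, 1], [2, 2, 1]])
iou, miou = mean_iou(confusion_matrix(pred, target, num_classes=3))
print(iou.round(3).tolist(), round(float(miou), 3))   # [0.5, 0.667, 1.0] 0.722
```

Accumulating the confusion matrix over the whole validation set before dividing, rather than averaging per-image IoUs, is the usual convention because it weights every pixel equally.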
# Test input: path to a single image, or to a directory of images
args.input_img = './tools/test.png'
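When `args.input_img` points to a directory, the test loop passes every entry returned by `os.listdir` to `Image.open`, so a stray non-image file or a subdirectory would raise an error. A hedged sketch of a stricter listing; the helper name `list_images` and the extension set are assumptions, not part of the repository.

```python
import os
import tempfile

IMAGE_EXTS = {".png", ".jpg", ".jpeg", ".bmp"}   # assumed set of accepted extensions

def list_images(path, exts=IMAGE_EXTS):
    """Return [path] for a file, or the sorted image files inside a directory."""
    if os.path.isdir(path):
        return sorted(
            os.path.join(path, name)
            for name in os.listdir(path)
            if os.path.splitext(name)[1].lower() in exts
            and os.path.isfile(os.path.join(path, name))
        )
    return [path]

# Demo: a directory with two images and one text file.
with tempfile.TemporaryDirectory() as tmp:
    for name in ["a.png", "b.JPG", "notes.txt"]:
        open(os.path.join(tmp, name), "w").close()
    found = [os.path.basename(p) for p in list_images(tmp)]
print(found)  # ['a.png', 'b.JPG']
```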
3.2 Load the Model
# image transform
transform = transforms.Compose([
transforms.ToTensor(),
transforms.Normalize(cfg.DATASET.MEAN, cfg.DATASET.STD),
])
model = get_segmentation_model().to(args.device)
model.eval()
print('Model loaded successfully')
INFO:root:load pretrained model from ./trained_model/model/best_model.pth
INFO:root:Shape unmatched weights: []
INFO:root:<All keys matched successfully>
Model loaded successfully
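The `transform` defined above converts a PIL image into the model's input layout: `ToTensor` turns an H×W×3 uint8 image into a 3×H×W float tensor scaled to [0, 1], `Normalize` subtracts a per-channel mean and divides by a per-channel standard deviation, and `unsqueeze(0)` in the test loop adds the batch dimension. A NumPy sketch of the same arithmetic, useful for checking preprocessing outside PyTorch; the ImageNet mean/std values are an assumption, since the notebook reads the real ones from `cfg.DATASET.MEAN` and `cfg.DATASET.STD`.

```python
import numpy as np

# Assumed ImageNet statistics; the notebook uses cfg.DATASET.MEAN / cfg.DATASET.STD.
MEAN = np.array([0.485, 0.456, 0.406], dtype=np.float32)
STD = np.array([0.229, 0.224, 0.225], dtype=np.float32)

def preprocess(image_hwc, mean=MEAN, std=STD):
    """uint8 H x W x 3 image -> float32 1 x 3 x H x W array."""
    x = image_hwc.astype(np.float32) / 255.0   # ToTensor: scale to [0, 1]
    x = (x - mean) / std                       # Normalize: broadcast over the channel axis
    x = x.transpose(2, 0, 1)                   # ToTensor: HWC -> CHW
    return x[np.newaxis]                       # unsqueeze(0): add the batch dimension

white = np.full((4, 6, 3), 255, dtype=np.uint8)
batch = preprocess(white)
print(batch.shape)          # (1, 3, 4, 6)
print(batch[0, :, 0, 0])    # (1 - mean) / std for each channel
```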
3.3 Run the Test
if os.path.isdir(args.input_img):
    img_paths = [os.path.join(args.input_img, x) for x in os.listdir(args.input_img)]
else:
    img_paths = [args.input_img]
for img_path in img_paths:
    image = Image.open(img_path).convert('RGB')
    images = transform(image).unsqueeze(0).to(args.device)  # add the batch dimension
    with torch.no_grad():
        output = model(images)
    # output[0] has shape (1, C, H, W): take the per-pixel argmax over classes
    pred = torch.argmax(output[0], 1).squeeze(0).cpu().data.numpy()
    mask = get_color_pallete(pred).convert('RGB')
    plt.figure(figsize=(20, 20))  # figure size (display scale)
    plt.subplot(1, 2, 1)
    plt.imshow(image)
    plt.subplot(1, 2, 2)
    plt.imshow(mask)
    plt.show()
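The test loop turns class scores into a colour image in two steps: `torch.argmax` over the class dimension picks one class ID per pixel, and `get_color_pallete` maps each ID to a colour. A NumPy sketch of both steps; the three-colour palette (the first three entries of the standard Cityscapes palette: road, sidewalk, building) and the helper `logits_to_mask` are illustrative assumptions, while the notebook uses the repository's own full palette.

```python
import numpy as np

# First three colours of the standard Cityscapes palette (RGB per class ID).
PALETTE = np.array([[128, 64, 128],    # class 0: road
                    [244, 35, 232],    # class 1: sidewalk
                    [70, 70, 70]],     # class 2: building
                   dtype=np.uint8)

def logits_to_mask(logits, palette=PALETTE):
    """(1, C, H, W) scores -> (H, W) class IDs and (H, W, 3) colour mask."""
    pred = logits.argmax(axis=1)[0]    # like torch.argmax(output[0], 1).squeeze(0)
    return pred, palette[pred]         # fancy indexing: one colour per pixel

logits = np.zeros((1, 3, 2, 2))        # ties resolve to class 0
logits[0, 2, 0, 0] = 1.0               # pixel (0, 0) favours class 2
logits[0, 1, 1, 1] = 1.0               # pixel (1, 1) favours class 1
pred, mask = logits_to_mask(logits)
print(pred.tolist())   # [[2, 0], [0, 1]]
```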