ModelArts Notebook快速开源项目实战 — U2Net

举报
shpity 发表于 2021/07/22 22:32:10 2021/07/22
【摘要】 本博文主要是对于U2Net在ModelArts Notbook上进行仿真实验。U2Net是一个优秀的显著性目标检测算法,由Qin Xuebin等人发表在Pattern Recognition 2020期刊。U2Net名称的来源在于其网络结构由两层嵌套的Unet结构,可以在不需要预训练骨干网络的情况下从零开始训练。

ModelArts Notebook快速开源项目实战 — U2Net

一、U2Net介绍

U2Net是一个优秀的显著性目标检测算法,由Qin Xuebin等人发表在Pattern Recognition 2020期刊[Arxiv]。U2Net名称的来源在于其网络结构由两层嵌套的Unet结构,可以在不需要预训练骨干网络的情况下从零开始训练,拥有优异的表现。其网络结构如图1所示。
image-20210722100412766

图1. U2Net的主体框架是一个类似于U-Net的编解码结构,但是每一个block替换为新提出的残差U-block模块

项目开源地址:https://github.com/xuebinqin/U-2-Net

二、创建Notebook开发环境

  1. 进入ModelArts控制台

  2. 选择开发环境 -> Notebook -> 创建

  3. 创建Notebook

    3.1 可以选择和任务相关的名称,方便管理;

    3.2 为了减少不必要的资源消耗,建议开启自动停止;

    3.3 U2Net所需的运行环境在公共镜像中已经包含,可以选择pytorch1.4-cuda10.1-cudnn7-ubuntu18.04

    3.4 建议选择GPU类型,方便模型快速训练;

    3.5 选择立即创建 -> 提交,等待notebook创建完成后打开Notebook。

    image-20210722102155943 image-20210722102852009 image-20210722103303867
  4. 导入开源项目源码(git/手动上传)

    4.1 在Terminal使用git克隆远程仓库

    cd work # 注意:只有/home/ma-user/work目录及其子目录下的文件在Notebook实例关闭后会保存
    git clone https://github.com/xuebinqin/U-2-Net.git
    

    4.2 如果git速度较慢也可以从本地上传代码,直接将压缩包拖到左侧文件目录栏或者采用OBS上传。

三、 数据准备

  1. 下载训练数据APDrawing dataset

    使用Wget直接下载到Notebook,也可下载本地后再拖拽到Notebook中。

    wget https://cg.cs.tsinghua.edu.cn/people/~Yongjin/APDrawingDB.zip
    unzip APDrawingDB.zip
    

    注:如果数据集较大(>5GB)需要下载到其它目录(实例停止后会被删除),建议存放在OBS中,需要的时候随时拉取。

    #从OBS中拉取代码到指定目录
    sh-4.4$ source /home/ma-user/anaconda3/bin/activate PyTorch-1.4
    sh-4.4$ python
    >>> mox.file.copy_parallel('obs://bucket-xxxx/APDrawingDB', '/home/ma-user/work/APDrawingDB')
    
  2. 切分训练数据

    数据集中./APDrawingDB/data/train中包含了420张训练图片,分辨率为512*1024,左侧为输入图像,右侧为对应的ground truth。我们需要将大图从中间切分为两个子图。

    img_1586

    2.1 在Notebook开发环境中新建一个Pytorch-1.4的jupyter Notebook文件,名称可以为split.ipynb,脚本将会在./APDrawingDB/data/train/split目录下生成840张子图,其中原始图像以.jpg结尾,gt图像以.png结尾,方便后续训练代码读取【test文件夹切分步骤同理】。

    from PIL import Image
    import os
    train_img_dir = os.path.join("./APDrawingDB/data/train")
    img_list = os.listdir(train_img_dir)
    for image in img_list:
        img_path = os.path.join(train_img_dir, image)
        if not os.path.isdir(img_path):
            img = Image.open(img_path)
            #print(img.size)
            save_img_dir = os.path.join(train_img_dir, 'split_train')
            if not os.path.exists(save_img_dir):
                os.mkdir(save_img_dir)
            save_img_path = os.path.join(save_img_dir, image)
            cropped_left = img.crop((0, 0, 512, 512))  # (left, upper, right, lower)
            cropped_right = img.crop((512, 0, 1024, 512))  # (left, upper, right, lower)
            cropped_left.save(save_img_path[:-3] + 'jpg')
            cropped_right.save(save_img_path)
    
    
    test_img_dir = os.path.join("./APDrawingDB/data/test")
    img_list = os.listdir(test_img_dir)
    for image in img_list:
        img_path = os.path.join(test_img_dir, image)
        if not os.path.isdir(img_path):
            img = Image.open(img_path)
            #print(img.size)
            save_img_dir = os.path.join(test_img_dir, 'split')
            if not os.path.exists(save_img_dir):
                os.mkdir(save_img_dir)
            save_img_path = os.path.join(save_img_dir, image)
            cropped_left = img.crop((0, 0, 512, 512))  # (left, upper, right, lower)
            cropped_right = img.crop((512, 0, 1024, 512))  # (left, upper, right, lower)
            cropped_left.save(save_img_path[:-3] + 'jpg')
    
  3. 将切分好的数据按照如下层级结构整理出训练和测试所需的datasets文件夹

    datasets/
    ├── test (70张切分图片,只包含原图)
    └── train (840张切分图片,包含420张原图及对应的gt)

    注:可以将切分好的数据集保存到OBS目录中,减少./work的磁盘空间占用。

  4. 完整的U-2-Net项目结构如下所示:

    U-2-Net/
    ├── .git
    ├── LICENSE
    ├── README.md
    ├── pycache
    ├── clipping_camera.jpg
    ├── data_loader.py
    ├── datasets
    ├── figures
    ├── gradio
    ├── model
    ├── requirements.txt
    ├── saved_models
    ├── setup_model_weights.py
    ├── test_data
    ├── u2net_human_seg_test.py
    ├── u2net_portrait_demo.py
    ├── u2net_portrait_test.py
    ├── u2net_test.py
    └── u2net_train.py

四、训练

  1. 官方提供的训练代码中数据的路径和我们的datasets有些区别,需要对训练脚本进行一些修改,建议使用jupyter notebook方便排除错误

    新建一个Pytorch-1.4的jupyter Notebook文件,名称可以为train.ipynb

    import moxing as mox
    # 如果需要从OBS拷贝切分好的训练数据
    #mox.file.copy_parallel('obs://bucket-test-xxxx', '/home/ma-user/work/U-2-Net/datasets')
    
    INFO:root:Using MoXing-v1.17.3-43fbf97f
    INFO:root:Using OBS-Python-SDK-3.20.7
    
    import os
    import torch
    import torchvision
    from torch.autograd import Variable
    import torch.nn as nn
    import torch.nn.functional as F
    
    from torch.utils.data import Dataset, DataLoader
    from torchvision import transforms, utils
    import torch.optim as optim
    import torchvision.transforms as standard_transforms
    
    import numpy as np
    import glob
    import os
    
    from data_loader import Rescale
    from data_loader import RescaleT
    from data_loader import RandomCrop
    from data_loader import ToTensor
    from data_loader import ToTensorLab
    from data_loader import SalObjDataset
    
    from model import U2NET
    from model import U2NETP
    
    /home/ma-user/anaconda3/envs/PyTorch-1.4/lib/python3.7/site-packages/skimage/io/manage_plugins.py:23: UserWarning: Your installed pillow version is < 7.1.0. Several security issues (CVE-2020-11538, CVE-2020-10379, CVE-2020-10994, CVE-2020-10177) have been fixed in pillow 7.1.0 or higher. We recommend to upgrade this library.
    from .collection import imread_collection_wrapper
    
    bce_loss = nn.BCELoss(size_average=True)
    
    /home/ma-user/anaconda3/envs/PyTorch-1.4/lib/python3.7/site-packages/torch/nn/_reduction.py:43: UserWarning: size_average and reduce args will be deprecated, please use reduction='mean' instead.
    warnings.warn(warning.format(ret))
    
    def muti_bce_loss_fusion(d0, d1, d2, d3, d4, d5, d6, labels_v):
    
    	loss0 = bce_loss(d0,labels_v)
    	loss1 = bce_loss(d1,labels_v)
    	loss2 = bce_loss(d2,labels_v)
    	loss3 = bce_loss(d3,labels_v)
    	loss4 = bce_loss(d4,labels_v)
    	loss5 = bce_loss(d5,labels_v)
    	loss6 = bce_loss(d6,labels_v)
    
    	loss = loss0 + loss1 + loss2 + loss3 + loss4 + loss5 + loss6
    	print("l0: %3f, l1: %3f, l2: %3f, l3: %3f, l4: %3f, l5: %3f, l6: %3f\n"%(loss0.data.item(),loss1.data.item(),loss2.data.item(),loss3.data.item(),loss4.data.item(),loss5.data.item(),loss6.data.item()))
    
    	return loss0, loss
    
    model_name = 'u2net' #'u2netp'
    
    data_dir = os.path.join(os.getcwd(), 'datasets', 'train' + os.sep)
    # tra_image_dir = os.path.join('DUTS', 'DUTS-TR', 'DUTS-TR', 'im_aug' + os.sep)
    # tra_label_dir = os.path.join('DUTS', 'DUTS-TR', 'DUTS-TR', 'gt_aug' + os.sep)
    
    image_ext = '.jpg'
    label_ext = '.png'
    
    model_dir = os.path.join(os.getcwd(), 'saved_models', model_name + os.sep)
    
    epoch_num = 100000
    batch_size_train = 24
    batch_size_val = 1
    train_num = 0
    val_num = 0
    
    tra_img_name_list = glob.glob(data_dir  + '*' + image_ext)
    
    tra_lbl_name_list = []
    for img_path in tra_img_name_list:
    	img_name = img_path.split(os.sep)[-1]
    
    	aaa = img_name.split(".")
    	bbb = aaa[0:-1]
    	imidx = bbb[0]
    	for i in range(1,len(bbb)):
    		imidx = imidx + "." + bbb[i]
    
    	tra_lbl_name_list.append(data_dir  + imidx + label_ext)
    
    print("---")
    print("train images: ", len(tra_img_name_list))
    print("train labels: ", len(tra_lbl_name_list))
    print("---")
    
    train_num = len(tra_img_name_list)
    
    ---
    train images:  420
    train labels:  420
    ---
    
    salobj_dataset = SalObjDataset(
        img_name_list=tra_img_name_list,
        lbl_name_list=tra_lbl_name_list,
        transform=transforms.Compose([
            RescaleT(320),
            RandomCrop(288),
            ToTensorLab(flag=0)]))
    salobj_dataloader = DataLoader(salobj_dataset, batch_size=batch_size_train, shuffle=True, num_workers=1)
    
    # ------- 3. define model --------
    # define the net
    if(model_name=='u2net'):
        net = U2NET(3, 1)
    elif(model_name=='u2netp'):
        net = U2NETP(3,1)
    
    if torch.cuda.is_available():
        net.cuda()
    
    # ------- 4. define optimizer --------
    print("---define optimizer...")
    optimizer = optim.Adam(net.parameters(), lr=0.001, betas=(0.9, 0.999), eps=1e-08, weight_decay=0)
    
    ---define optimizer...
    
    # ------- 5. training process --------
    print("---start training...")
    ite_num = 0
    running_loss = 0.0
    running_tar_loss = 0.0
    ite_num4val = 0
    save_frq = 2000 # save the model every 2000 iterations
    
    ---start training...
    
    for epoch in range(0, epoch_num):
        net.train()
    
        for i, data in enumerate(salobj_dataloader):
            ite_num = ite_num + 1
            ite_num4val = ite_num4val + 1
    
            inputs, labels = data['image'], data['label']
    
            inputs = inputs.type(torch.FloatTensor)
            labels = labels.type(torch.FloatTensor)
    
            # wrap them in Variable
            if torch.cuda.is_available():
                inputs_v, labels_v = Variable(inputs.cuda(), requires_grad=False), Variable(labels.cuda(),
                                                                                            requires_grad=False)
            else:
                inputs_v, labels_v = Variable(inputs, requires_grad=False), Variable(labels, requires_grad=False)
    
            # y zero the parameter gradients
            optimizer.zero_grad()
    
            # forward + backward + optimize
            d0, d1, d2, d3, d4, d5, d6 = net(inputs_v)
            loss2, loss = muti_bce_loss_fusion(d0, d1, d2, d3, d4, d5, d6, labels_v)
    
            loss.backward()
            optimizer.step()
    
            # # print statistics
            running_loss += loss.data.item()
            running_tar_loss += loss2.data.item()
    
            # del temporary outputs and loss
            del d0, d1, d2, d3, d4, d5, d6, loss2, loss
            print("[epoch: %3d/%3d, batch: %5d/%5d, ite: %d] train loss: %3f, tar: %3f " % (
            epoch + 1, epoch_num, (i + 1) * batch_size_train, train_num, ite_num, running_loss / ite_num4val, running_tar_loss / ite_num4val))
    
            if ite_num % save_frq == 0:
                model_weight = model_dir + model_name+"_bce_itr_%d_train_%3f_tar_%3f.pth" % (ite_num, running_loss / ite_num4val, running_tar_loss / ite_num4val)
                torch.save(net.state_dict(),  model_weight)
                mox.file.copy_parallel(model_weight, 'obs://bucket-xxxx/output/model_save/' + model_weight.split('/')[-1])
                running_loss = 0.0
                running_tar_loss = 0.0
                net.train()  # resume train
                ite_num4val = 0
    
    l0: 0.167562, l1: 0.153742, l2: 0.156246, l3: 0.163096, l4: 0.176632, l5: 0.197176, l6: 0.247590
    
    [epoch:   1/100000, batch:    24/  420, ite: 500] train loss: 1.189413, tar: 0.159183 
    l0: 0.188048, l1: 0.179041, l2: 0.180086, l3: 0.187904, l4: 0.198345, l5: 0.218509, l6: 0.269199
    
    [epoch:   1/100000, batch:    48/  420, ite: 501] train loss: 1.266652, tar: 0.168805 
    l0: 0.192491, l1: 0.187615, l2: 0.188043, l3: 0.197142, l4: 0.203571, l5: 0.222019, l6: 0.261745
    
    [epoch:   1/100000, batch:    72/  420, ite: 502] train loss: 1.313146, tar: 0.174727 
    l0: 0.169403, l1: 0.155883, l2: 0.157974, l3: 0.164012, l4: 0.175975, l5: 0.195938, l6: 0.244896
    
    [epoch:   1/100000, batch:    96/  420, ite: 503] train loss: 1.303333, tar: 0.173662 
    l0: 0.171904, l1: 0.157170, l2: 0.156688, l3: 0.162020, l4: 0.175565, l5: 0.200576, l6: 0.258133
    
    [epoch:   1/100000, batch:   120/  420, ite: 504] train loss: 1.299787, tar: 0.173369 
    l0: 0.177398, l1: 0.166131, l2: 0.169089, l3: 0.176976, l4: 0.187039, l5: 0.205449, l6: 0.248036
    

五、测试

新建一个Pytorch-1.4的jupyter Notebook文件,名称可以为test.ipynb

import moxing as mox
# 拷贝数据
mox.file.copy_parallel('obs://bucket-xxxx/output/model_save/u2net.pth', '/home/ma-user/work/U-2-Net/saved_models/u2net/u2net.pth')
import os
import sys
from skimage import io, transform
import torch
import torchvision
from torch.autograd import Variable
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms#, utils
# import torch.optim as optim

import numpy as np
from PIL import Image
import glob

from data_loader import RescaleT
from data_loader import ToTensor
from data_loader import ToTensorLab
from data_loader import SalObjDataset

from model import U2NET # full size version 173.6 MB
from model import U2NETP # small version u2net 4.7 MB
%matplotlib inline
import matplotlib.pyplot as plt
from matplotlib.pyplot import imshow
# normalize the predicted SOD probability map
def normPRED(d):
    ma = torch.max(d)
    mi = torch.min(d)

    dn = (d-mi)/(ma-mi)

    return dn

def save_output(image_name,pred,d_dir, show=False):

    predict = pred
    predict = predict.squeeze()
    predict_np = predict.cpu().data.numpy()

    im = Image.fromarray(predict_np*255).convert('RGB')
    img_name = image_name.split(os.sep)[-1]
    image = io.imread(image_name)
    imo = im.resize((image.shape[1],image.shape[0]),resample=Image.BILINEAR)

    pb_np = np.array(imo)
    if show:
        show_on_notebook(image, im)
    aaa = img_name.split(".")
    bbb = aaa[0:-1]
    imidx = bbb[0]
    for i in range(1,len(bbb)):
        imidx = imidx + "." + bbb[i]

    imo.save(d_dir+imidx+'.png')
    return im
    
def show_on_notebook(image_original, pred): #此函数可以在notebook中展示模型的预测效果
    plt.subplot(1,2,1)
    imshow(np.array(image_original))
    plt.subplot(1,2,2)
    imshow(np.array(pred))
    

# --------- 1. get image path and name ---------
model_name='u2net'#u2netp

image_dir = os.path.join(os.getcwd(), 'datasets', 'test') #注意这里的test_data/original存放的是datasets/test中的原始图片,不包含gt
prediction_dir = os.path.join(os.getcwd(), 'output', model_name + '_results' + os.sep)
model_dir = os.path.join(os.getcwd(), 'saved_models', model_name, model_name + '.pth')


img_name_list = glob.glob(os.path.join(os.getcwd(), 'datasets/test/*.jpg'))
# print(img_name_list)

# --------- 2. dataloader ---------
#1. dataloader
test_salobj_dataset = SalObjDataset(img_name_list = img_name_list,
                                    lbl_name_list = [],
                                    transform=transforms.Compose([RescaleT(320),
                                                                  ToTensorLab(flag=0)])
                                    )
test_salobj_dataloader = DataLoader(test_salobj_dataset,
                                    batch_size=1,
                                    shuffle=False,
                                    num_workers=1)

# --------- 3. model define ---------
if(model_name=='u2net'):
    print("...load U2NET---173.6 MB")
    net = U2NET(3,1)
elif(model_name=='u2netp'):
    print("...load U2NEP---4.7 MB")
    net = U2NETP(3,1)

if torch.cuda.is_available():
    net.load_state_dict(torch.load(model_dir))
    net.cuda()
else:
    net.load_state_dict(torch.load(model_dir, map_location='cpu'))
net.eval()

# --------- 4. inference for each image ---------
for i_test, data_test in enumerate(test_salobj_dataloader):

#     print("inferencing:",img_name_list[i_test].split(os.sep)[-1])

    inputs_test = data_test['image']
    
    inputs_test = inputs_test.type(torch.FloatTensor)

    if torch.cuda.is_available():
        inputs_test = Variable(inputs_test.cuda())
    else:
        inputs_test = Variable(inputs_test)

    d1,d2,d3,d4,d5,d6,d7= net(inputs_test)

    # normalization
    pred = d1[:,0,:,:]
    pred = normPRED(pred)

    # save results to test_results folder
    if not os.path.exists(prediction_dir):
        os.makedirs(prediction_dir, exist_ok=True)
    save_output(img_name_list[i_test],pred,prediction_dir, show=True)
#     sys.exit(0)

    del d1,d2,d3,d4,d5,d6,d7

image-20210722134602040

六、附件

见附件

想了解更多的AI技术干货,欢迎上华为云的AI专区,目前有AI编程Python等六大实战营供大家免费学习。(六大实战营link:http://su.modelarts.club/qQB9)

    附件下载

  • 附件.zip 71.28KB 下载次数:32
【版权声明】本文为华为云社区用户原创内容,转载时必须标注文章的来源(华为云社区)、文章链接、文章作者等基本信息, 否则作者和本社区有权追究责任。如果您发现本社区中有涉嫌抄袭的内容,欢迎发送邮件进行举报,并提供相关证据,一经查实,本社区将立刻删除涉嫌侵权内容,举报邮箱: cloudbbs@huaweicloud.com
  • 点赞
  • 收藏
  • 关注作者

评论(0

0/1000
抱歉,系统识别当前为高风险访问,暂不支持该操作

全部回复

上滑加载中

设置昵称

在此一键设置昵称,即可参与社区互动!

*长度不超过10个汉字或20个英文字符,设置后3个月内不可修改。

*长度不超过10个汉字或20个英文字符,设置后3个月内不可修改。