【华为云-上云之路】【昇腾】【每天进步一点点】如何基于官方例程开发自己想实现的项目(Python版本)

举报
Tianyi_Li 发表于 2020/05/31 21:44:04 2020/05/31
【摘要】 很多时候,基于各种需要,我们希望能在Atlas平台上开发自己想要的项目,虽然官方给出了很多例程项目,但很多时候并不是我们想要的效果,可能需要更换输入方法,比如官方例程给的是摄像头输入,我们希望是本地视频输入等等问题,这个时候就需要自己去开发实现了,不过,做好是基于官方项目例程去实现,尽量不要从零开始,自己去开发,因为例程是经过专业工程师调试的。这里基于分类项目修改实现黑白图像上色。

前言

很多时候,基于各种需要,我们希望能在Atlas平台上开发自己想要的项目,虽然官方给出了很多例程项目,涉及分类、目标检测、分割等常用场景,但很多时候并不是我们想要的效果,可能需要更换输入方法,比如官方例程给的是摄像头输入,我们希望是本地视频输入等等问题,这个时候就需要自己去开发实现了,不过,做好是基于官方项目例程去实现,尽量不要从零开始,自己去开发,因为例程是经过专业工程师调试的,可以确保能在设备上运行起来,如果自己从头开发,可能会遇到很多工程师已经走过的坑,开发效率就比较低了。站在巨人的肩膀上,基于现成的案例去开发,避免重复造轮子。

开发案例

这里基于社区例程sample-classification-python(链接为https://gitee.com/Atlas200DK/sample-classification-python/tree/master)实现黑白图像上色,其实黑白图像上色这个项目是官方例程中的一个,记得以前有Python版本,现在好像找不到了,无所谓了,现在来试试吧。

1. 模型转换

可以说这是整个项目的核心部分了,首先要的到用于黑白图像上色的.om模型,这个可以参考案例找到模型,并按照要求转换https://gitee.com/Atlas200DK/sample-README/tree/master/sample-colorization

2. 改造代码

首先,sample-classification-python比较简单,使用了一个模型,而且主要代码在classify.py中实现了,代码中明确指定了模型和输入图像的路径,我们只要相应修改就行了,把模型换为刚才转化的图像上色的模型,图像输入路径放上待上色的图片就行了,注意这里图片要是.jpg结尾的图片,因为在代码中是通过.jpg结尾来确定图片,并输入模型的,这要注意。

其次,就是模型输入尺寸大小了,这一点根据模型转换时设置的尺寸做修改。

最后,是后处理,这里很简单,直接写为图片就行了,但也因为如此,程序运行较慢,写图片比较费时间,而且使用OpenCV写的图片好像质量一般。

至此,大功告成。看看修改好的代码吧。

# coding=utf-8

import hiai
import imageNetClasses
import os
import numpy as np
import time
import graph
import post_process
import cv2

cur_path = os.path.dirname(os.path.realpath(__file__))
os.chdir(cur_path)
# resnet18OmFileName='./models/resnet18.om'
resnet18OmFileName='./models/colorization.om'
srcFileDir = './ImageNetRaw/'
dstFileDir = './resnet18Result/'


# 初始参数设定
resize_w = 224
resize_h = 224
out_w = 56
out_h = 56


def preprocess(img_bgr):
    bgr_img = img_bgr.astype(np.float32)
    orig_shape = bgr_img.shape[:2]
    bgr_img = bgr_img / 255.0
    lab_img = cv2.cvtColor(bgr_img, cv2.COLOR_BGR2Lab)
    orig_l = lab_img[:,:,0]
    if not orig_l.flags['C_CONTIGUOUS']:
        orig_l = np.ascontiguousarray(orig_l)
    lab_img = cv2.resize(lab_img, (resize_w, resize_h)).astype(np.float32)
    l_data = lab_img[:,:,0]
    if not l_data.flags['C_CONTIGUOUS']:
        l_data = np.ascontiguousarray(l_data)
    l_data = l_data - 50
    return orig_shape, orig_l, l_data


def postprocess(result_list, orig_shape, orig_l):
    result_array = result_list[0][0]
    ab_data = cv2.resize(result_array, orig_shape[::-1])
    result_lab = np.concatenate((orig_l[:,:,np.newaxis],ab_data),axis=2)
    result_bgr = (255*np.clip(cv2.cvtColor(result_lab, cv2.COLOR_Lab2BGR),0,1)).astype('uint8')
    # file_name = os.path.join(output_path, 'out_'+pic)
    #cv.imwrite(file_name, result_bgr)
    
    # result_rgb = cv2.cvtColor(result_bgr, cv2.COLOR_RGB2BGR) # 转为RGB格式
    result_rgb = result_bgr
    return result_rgb


def Resnet18PostProcess(resultList, srcFilePath, dstFilePath, fileName, orig_shape, orig_l):
        if resultList is not None :
                # firstConfidence, firstClass = post_process.GenerateTopNClassifyResult(resultList, 1)
                # firstLabel = imageNetClasses.imageNet_classes[firstClass[0]]
                dstFileName = os.path.join(dstFilePath, fileName)
                # srcFileName = os.path.join(srcFilePath, fileName)
                # image = cv.imread(srcFileName)
                # txt = firstLabel + " " + str(round(firstConfidence[0]*100,2))
                # cv.putText(image, txt, (15,20), cv.FONT_HERSHEY_COMPLEX_SMALL, 1.0, (0, 0, 255))
                resultList[0] = resultList[0].reshape(1,2,56,56).transpose(0,2,3,1)
                img_rgb = postprocess(resultList, orig_shape, orig_l)
		
		
                cv2.imwrite(dstFileName, img_rgb)
        else :
                print('graph inference failed ')
                return None


def main():
        try:
                myGraph = graph.Graph(resnet18OmFileName)
                myGraph.CreateGraph()
        except Exception as e:
                print("Except:", e)
                return
        # dvppInWidth = 224
        # dvppInHeight = 224
        start = time.time()
        if not os.path.exists(dstFileDir):
                os.mkdir(dstFileDir)
        pathDir = os.listdir(srcFileDir)
        for allDir in pathDir:
                child = os.path.join(srcFileDir, allDir)
                img_bgr = cv2.imread(child)
                # input_image = cv2.resize(input_image, (dvppInWidth, dvppInHeight))
                orig_shape, orig_l, l_data = preprocess(img_bgr) # 缩放为模型输入尺寸
                resultList = myGraph.Inference(l_data)
                if resultList is None:
                        print("graph inference failed")
                        continue
                Resnet18PostProcess(resultList, srcFileDir, dstFileDir, allDir, orig_shape, orig_l)
                
        end = time.time()
        print('cost time '+str((end-start)*1000)+'ms')
        myGraph.Destroy()
        print('-------------------end')


if __name__ == "__main__":
        main()


再来看看原来的代码

#coding=utf-8

import hiai
import imageNetClasses
import os
import numpy as np
import time
import graph
import post_process
import cv2 as cv

cur_path = os.path.dirname(os.path.realpath(__file__))
os.chdir(cur_path)
resnet18OmFileName='./models/resnet18.om'
srcFileDir = './ImageNetRaw/'
dstFileDir = './resnet18Result/'


def Resnet18PostProcess(resultList, srcFilePath, dstFilePath, fileName):
        if resultList is not None :
                firstConfidence, firstClass = post_process.GenerateTopNClassifyResult(resultList, 1)
                firstLabel = imageNetClasses.imageNet_classes[firstClass[0]]
                dstFileName = os.path.join(dstFilePath, fileName)
                srcFileName = os.path.join(srcFilePath, fileName)
                image = cv.imread(srcFileName)
                txt = firstLabel + " " + str(round(firstConfidence[0]*100,2))
                cv.putText(image, txt, (15,20), cv.FONT_HERSHEY_COMPLEX_SMALL, 1.0, (0, 0, 255))
                cv.imwrite(dstFileName, image)
        else :
                print('graph inference failed ')
                return None


def main():
        try:
                myGraph = graph.Graph(resnet18OmFileName)
                myGraph.CreateGraph()
        except Exception as e:
                print("Except:", e)
                return
        dvppInWidth = 224
        dvppInHeight = 224
        start = time.time()
        if not os.path.exists(dstFileDir):
                os.mkdir(dstFileDir)
        pathDir = os.listdir(srcFileDir)
        for allDir in pathDir:
                child = os.path.join(srcFileDir, allDir)
                input_image = cv.imread(child)
                input_image = cv.resize(input_image, (dvppInWidth, dvppInHeight))
                resultList = myGraph.Inference(input_image)
                if resultList is None:
                        print("graph inference failed")
                        continue
                Resnet18PostProcess(resultList, srcFileDir, dstFileDir, allDir)
        end = time.time()
        print('cost time '+str((end-start)*1000)+'ms')
        myGraph.Destroy()
        print('-------------------end')


if __name__ == "__main__":
        main()


对比可见,修改的还是比较少的。在官方例程上修改,还是比较轻松的,可以快速实现。当然,如果需要工业部署,可能要用C++开发了,C++比Python会难一些,所以最好在官方例程上修改,这样可以省下很多时间,提高开发效率。

下面看看运行结果

image.png


最终效果展示:

1. 素描

image.png

2. 近景

image.png

3. 人物

image.png

image.png


4. 自然风光

image.png


image.png


image.png


放大看细节,你会发现上色效果非常好,光线感自然,仿佛浑然天成,特别是右下角的森林,层次感非常好,仿佛流动的绿色海洋。这张图达到了1.33MB,未上色的原图是1.77MB,细节表现的很好。

test03.jpg


最后,展示一段上色的老视频,原视频为一段航拍老北京的视频,通过视频可以看到上色效果并不是很好,这是因为原视频清晰度较差,分辨率低,特别是画面中物体边界不明显。一般情况下,对老视频上色,首先要做补帧和提高分辨率,或者称为超分辨技术,提升分辨率,使得画面景物边界明显,否则模糊的边界对上色影响很大。


                                                                 


最后,奉上全部代码(包括上述测试文件),下载解压后,拷贝到Atlas 200 DK上,运行classificationapp文件夹下的classify.py文件即可,即执行命令

python3 classify.py

# Python2可能也行,没试过。

等待即可,即可得到前面的结果。注意时间有点长,可以先干点其他事情。

链接:https://pan.baidu.com/s/15lIpjBJ94-peq0-oB-PHXg 

提取码:ocj3




【版权声明】本文为华为云社区用户原创内容,转载时必须标注文章的来源(华为云社区)、文章链接、文章作者等基本信息, 否则作者和本社区有权追究责任。如果您发现本社区中有涉嫌抄袭的内容,欢迎发送邮件进行举报,并提供相关证据,一经查实,本社区将立刻删除涉嫌侵权内容,举报邮箱: cloudbbs@huaweicloud.com
  • 点赞
  • 收藏
  • 关注作者

评论(0

0/1000
抱歉,系统识别当前为高风险访问,暂不支持该操作

全部回复

上滑加载中

设置昵称

在此一键设置昵称,即可参与社区互动!

*长度不超过10个汉字或20个英文字符,设置后3个月内不可修改。

*长度不超过10个汉字或20个英文字符,设置后3个月内不可修改。