[线上模型挑战] ICNet 模型迁移(TensorFlow GPU To Ascend 910)记录

举报
JeffDing 发表于 2020/12/12 05:33:07 2020/12/12
【摘要】 [线上模型挑战] ICNet 模型迁移(TensorFlow GPU To Ascend 910)记录, ICNet 主要用于图像实时语义分割,能够兼顾速度和准确性。 ICNet的主要思想是将输入图像变换为不同的分辨率,然后用不同计算复杂度的子网络计算不同分辨率的输入,然后将结果合并。

一、ICNET介绍


ICNet 主要用于图像实时语义分割,能够兼顾速度和准确性。 ICNet的主要思想是将输入图像变换为不同的分辨率,然后用不同计算复杂度的子网络计算不同分辨率的输入,然后将结果合并。

论文原文:https://arxiv.org/pdf/1704.08545.pdf

二、代码原型

https://github.com/hellochick/ICNet-tensorflow

三、预训练模型准备

通过COLAB平台将数据下载后拷贝到本地后再上传

from google_drive_downloader import GoogleDriveDownloader as gdd
gdd.download_file_from_google_drive(file_id='15S_vZoZZwBsORxtRAMcbdsI99o6Cvo5x',
                                        dest_path='./model/cityscapes/icnet_cityscapes_train_30k_bnnomerge.npy',
                                        unzip=False)
gdd.download_file_from_google_drive(file_id='17ZILbQ7Qazg7teb567CIPJ30FD57bVVg',
                                        dest_path='./model/cityscapes/icnet_cityscapes_train_30k.npy',
                                        unzip=False)
gdd.download_file_from_google_drive(file_id='1Z-slNrKYJpfpELeuh2UlueQG1krF9I4a',
                                        dest_path='./model/cityscapes/icnet_cityscapes_trainval_90k_bnnomerge.npy',
                                        unzip=False)
gdd.download_file_from_google_drive(file_id='1tZIHpppPcleamBlXKSzjOqL93gNjWGec',
                                        dest_path='./model/cityscapes/icnet_cityscapes_trainval_90k.npy',
                                        unzip=False)

四、数据集准备

cityscapes:数据集下载需要注册且不能是公共域名下的邮箱,不过在主办方提供的环境的数据集中有该数据集,直接复制到代码目录下的data/cityscapes_dataset目录下

cp -r /data/dataset/storage/cityscapes/ ./data/cityscapes_dataset/

ade20k:

wget -O ./data/ADEChallengeData2016.zip http://data.csail.mit.edu/places/ADEchallenge/ADEChallengeData2016.zip
unzip ./data/ADEChallengeData2016.zip -d ./data
rm ./data/ADEChallengeData2016.zip
echo "Dataset downloaded."

五、修改代码

可以参考开发指南:https://support.huaweicloud.com/mprtg-A800_9000_9010/atlasprtg_13_0006.html

主要修改network.py文件

在代码开头引入NPU包

from npu_bridge.estimator import npu_ops
from tensorflow.core.protobuf.rewriter_config_pb2 import RewriterConfig

在def create_session中的sess.run前面加上以下语句块

def create_session(self):
        # Set up tf session and initialize variables.
        #config = tf.ConfigProto()
        #config.gpu_options.allow_growth = True
        
        global_init = tf.global_variables_initializer()
        local_init = tf.local_variables_initializer()
        
        config = tf.ConfigProto()
        custom_op =  config.graph_options.rewrite_options.custom_optimizers.add()
        custom_op.name =  "NpuOptimizer"
        custom_op.parameter_map["use_off_line"].b = True # 必须显示开启,在昇腾AI处理器执行训练
        config.graph_options.rewrite_options.remapping = RewriterConfig.OFF  # 必须显示关闭remap
        #config.graph_options.rewrite_options.optimizers.extend(["GradFusionOptimizer"]) #分布式添加
        
        self.sess = tf.Session(config=config)
        self.sess.run([global_init, local_init])config = tf.ConfigProto()
            custom_op =  config.graph_options.rewrite_options.custom_optimizers.add()
            custom_op.name =  "NpuOptimizer"
            custom_op.parameter_map["use_off_line"].b = True # 必须显示开启,在昇腾AI处理器执行训练
            config.graph_options.rewrite_options.remapping = RewriterConfig.OFF  # 必须显示关闭remap
            #config.graph_options.rewrite_options.optimizers.extend(["GradFusionOptimizer"]) #分布式添加

可能network.py中也需要修改一下

在np.load中加入allow_pickle=True,这一参数

六、运行结果

python train.py \
  --update-mean-var \
  --train-beta-gamma \
  --random-scale \
  --random-mirror \
  --dataset=cityscapes \
  --filter-scale=1 >train.log 2>&1

七、初次运行失败原因

要想在框架不支持动态shape的背景下跑通网络,暂时只能走规避手段,也就是让所有where的输入节点在编译阶段就可以常量折叠成const,这样就可以在推到shape时拿到输入的实际数据,从而避免推导出动态shape。

上述规避措施依赖两个条件:
1、where的输入节点经过多个节点后最终都来源于const的常量
2、where的输入连接到常量的路径上,所有经过的节点都完成常量折叠的实现

根据上述的条件去图中查找相关接口,结果如下:
6个where的输入都来源于类型为LessEqual的节点,且结构相似,仅选取一例做示例:
如下图所示,where的输入最终来源于3个constant(蓝色部分),1个data(红色部分),有一个输入来源自网络的输入,因此上述的条件一就不满足了。无法通过常量折叠方式规避此问题。
所以只能依赖框架对动态shape的支持能力完备后,此网络才可以继续展开调试。

【版权声明】本文为华为云社区用户原创内容,转载时必须标注文章的来源(华为云社区)、文章链接、文章作者等基本信息, 否则作者和本社区有权追究责任。如果您发现本社区中有涉嫌抄袭的内容,欢迎发送邮件进行举报,并提供相关证据,一经查实,本社区将立刻删除涉嫌侵权内容,举报邮箱: cloudbbs@huaweicloud.com
  • 点赞
  • 收藏
  • 关注作者

评论(0

0/1000
抱歉,系统识别当前为高风险访问,暂不支持该操作

全部回复

上滑加载中

设置昵称

在此一键设置昵称,即可参与社区互动!

*长度不超过10个汉字或20个英文字符,设置后3个月内不可修改。

*长度不超过10个汉字或20个英文字符,设置后3个月内不可修改。

举报
请填写举报理由
0/200