教你如何dump算子的输入输出
一 环境准备
安装cann包和mindspore-lite(前面文章已经写了,这里不重复)
二 dump算子输入输出
当模型精度出现问题时,逐个dump算子的输入输出可以配合二分法定位是哪个算子的问题
或者
profiling数据分析后,发现某个算子耗时严重,dump算子的输入输出数据可以发现数据的特点
1. 使用benchmark工具dump数据
source /usr/local/Ascend/ascend-toolkit/set_env.sh && benchmark --modelFile=/xxx/f2f.mindir --device=Ascend310 --configFile=/xxx/config.cni
其中modelFile是使用convert工具将onnx模型文件转换得到的
配置文件config.cni内容如下:
[ascend_context]
dump_config_file=./dump.json
dump.json内容如下:
{
"dump":{
"dump_list":[
{
"model_name":"f2f"
}
],
"dump_path":"/xxx/dump_output",
"dump_mode":"all",
"dump_op_switch":"off"
}
}
需要注意的是:
1)model_name必须是模型的名称,不能带后缀,也不允许在前面添加文件路径,因此转换后的模型必须放在脚本同级目录下
2)dump输出目录必须提前创建
dump后会在输出目录下生成一些二进制文件,如下:
root@2cf0a033f372:/xxx/dump_output/20231008004857/0/f2f/1/0# ll
total 724764
drwx------ 2 root root 8192 Oct 8 00:48 ./
drwx------ 15 root root 126 Oct 8 00:49 ../
-rw------- 1 root root 7058214 Oct 8 00:48 ScatterNdUpdate.ScatterND_313.70.172.1696726138761095
-rw------- 1 root root 1573044 Oct 8 00:48 ScatterNdUpdate.ScatterND_72.10.172.1696726137326402
...
2. 将dump后数据生成np文件
python3 /usr/local/Ascend/ascend-toolkit/6.3.RC2/tools/operator_cmp/compare/msaccucmp.py convert -d /xxx/dump_output/20231008004857/0/f2f/1/0/ -out /xxx/dump_data_in_numpy
需要注意的是:
1)-d后面的路径必须到最底层目录
2)如果出现protobuf相关错误,可以执行 export PROTOCOL_BUFFERS_PYTHON_IMPLEMENTATION=python
执行成功后,可以在输出目录下看到生成的np文件,如下:
root@2cf0a033f372:/home/yuqing/f2f/output/f2f_fp16_aoe_1695624629/benchmark/dump_data_in_numpy# ll
total 725528
drwxr-xr-x 2 root root 20480 Oct 8 01:12 ./
drwxr-x--- 4 root root 313 Oct 8 00:57 ../
-rw------- 1 root root 1572992 Oct 8 01:12 ScatterNdUpdate.ScatterND_313.70.172.1696726138761095.input.0.npy
-rw------- 1 root root 3477728 Oct 8 01:12 ScatterNdUpdate.ScatterND_313.70.172.1696726138761095.input.1.npy
-rw------- 1 root root 434828 Oct 8 01:12 ScatterNdUpdate.ScatterND_313.70.172.1696726138761095.input.2.npy
-rw------- 1 root root 1572992 Oct 8 01:12 ScatterNdUpdate.ScatterND_313.70.172.1696726138761095.output.0.npy
...
可以使用numpy的b = np.load('xxx.npy') 读取npy文件,然后打印看看
由于本例中我们只想分析ScatterNdUpdate算子的输入输出,因此上面也只展示了与这个算子相关的
附:
ScatterNdUpdate算子一般有3个输入:ref, indices, updates
三个都是tensor,ref为待修改tensor,indices为修改的坐标,updates为修改使用的数据
举例说明:
1,1,512,512; 1,1,2; 1,1,512,512
表示使用updates的512*512数值一次替换ref相应的512*512位置的值
indices:1,1,2 需要拆分两部分,前面的1,1 表示更新的高纬度为1,1,后面的2表示高纬度的坐标,这里其值只能是(0,0)
同样
1,3,512,512; 1,3,225,322,4; 1,3,225,322
表示更新1,3,225,322个值,每个值的坐标用4维表示,每次只更新一个数
- 点赞
- 收藏
- 关注作者
评论(0)