Pytorch网络结构可视化
现在用的这个:
net = load_model(net, args.trained_model, args.cpu)
for name, param in net.named_parameters():
print(param.size(),name, )
下面这个报错:
Pytorch网络结构可视化
使用这个网址,超级强悍
https://lutzroeder.github.io/netron/
可以通过以下的命令进行安装
pip install graphviz
pip install torch torchvision
pip install tensorwatch
载入库
import sys
import torch
import tensorwatch as tw
import torchvision.models
网络结构可视化
model = torchvision.models.alexnet()
tw.draw_model(model, [1, 3, 224, 224])
载入alexnet,draw_model函数需要传入三个参数,第一个为model,第二个参数为input_shape,第三个参数为orientation,可以选择’LR’或者’TB’,分别代表左右布局与上下布局。
在notebook中,执行完上面的代码会显示如下的图,将网络的结构及各个层的name和shape进行了可视化。
统计网络参数
可以通过model_stats方法统计各层的参数情况。
tw.model_stats(alexnet_model, [1, 3, 224,
文章来源: blog.csdn.net,作者:网奇,版权归原作者所有,如需转载,请联系作者。
原文链接:blog.csdn.net/jacke121/article/details/103689899
- 点赞
- 收藏
- 关注作者
评论(0)