Pytorch gpu加速方法

举报
风吹稻花香 发表于 2021/09/09 23:08:19 2021/09/09
【摘要】 Pytorch gpu加速方法 原文: https://www.zhihu.com/question/274635237 relu 用 inplace=True用 eval() 和 with torch.no_grad():每个 batch 后认真的把所有参数从 GPU 拿出来后删除虽然...

Pytorch gpu加速方法

原文:

https://www.zhihu.com/question/274635237

  • relu 用 inplace=True
  • 用 eval() 和 with torch.no_grad():
  • 每个 batch 后认真的把所有参数从 GPU 拿出来后删除
  • 虽然很多回答建议用, 但我建议不要用 torch.cuda.empty_cache() , 这只是释放 GPU 缓存而使得 nvidia-smi 能看得见 pytorch 自动释放的内存而已. 99% 的用户不需要使用这个命令. 并有用户反应每次用反而会减慢 1~2s.[1]
  • 注意: 当每张 GPU 里面的 batch_size 太小(<8)时用 batch_norm 会导致训练不稳定, 除非你用以下所说的 APEX 来实现多 GPU sync_bn
  • torch.backends.cudnn.deterministic = True 用不用对 GPU 内存占用和效率都没有什么太大的影响. 建议开着.
  • 不要用 .cpu() 来取 GPU 里面出来的图片. 这样做的话训练时长可能翻倍.

实现: 研究 pytorch 官方架构就会发现大部分 forward pass 都是 `x = self.conv(x)` 的形式, 很少 introduce new variable. 所以: (1) 把不需要的变量都由 `x` 代替; (2) 变量用完后用 `del` 删除.

例子


  
  1. def forward(self, x):
  2. conv2 = self.conv2(self.conv1(x)) #1/4
  3. del x
  4. conv3 = self.conv3(conv2) #1/8
  5. conv4 = self.conv4(conv3) #1/16
  6. conv5 = self.conv5(conv4) #1/32
  7. center_64 = self.center_conv1x1(self.center_global_pool(conv5))
  8. d5 = self.decoder5(self.center(conv5), conv5)
  9. del conv5
  10. d4 = self.decoder4(d5, conv4)
  11. del conv4
  12. d3 = self.decoder3(d4, conv3)
  13. del conv3
  14. d2 = self.decoder2(d3, conv2)
  15. del conv2

如果你按照上面的方法把 pin_memory 开启了的话, 请数据放入 GPU 的时候把 non_blocking 开启. 这样如果你只把数据放入 GPU 而不把数据从 GPU 拿出来再做计算的话就会加快很多 (据用户报告可加速 50%). 就算你把 GPU 中数据拿出来 (ie. 用了 .cpu() 命令, 最坏的结果也是与 non_blocking=False 相当:


  
  1. """Sync Point"""
  2. image = image.cuda(non_blocking=True)
  3. labels = labels.cuda(non_blocking=True).float()
  4. """Async Point"""
  5. prediction = net(image)

 

文章来源: blog.csdn.net,作者:AI视觉网奇,版权归原作者所有,如需转载,请联系作者。

原文链接:blog.csdn.net/jacke121/article/details/120190974

【版权声明】本文为华为云社区用户转载文章,如果您发现本社区中有涉嫌抄袭的内容,欢迎发送邮件进行举报,并提供相关证据,一经查实,本社区将立刻删除涉嫌侵权内容,举报邮箱: cloudbbs@huaweicloud.com
  • 点赞
  • 收藏
  • 关注作者

评论(0

0/1000
抱歉,系统识别当前为高风险访问,暂不支持该操作

全部回复

上滑加载中

设置昵称

在此一键设置昵称,即可参与社区互动!

*长度不超过10个汉字或20个英文字符,设置后3个月内不可修改。

*长度不超过10个汉字或20个英文字符,设置后3个月内不可修改。