归一化EvoNorms

举报
风吹稻花香 发表于 2021/06/04 23:39:32 2021/06/04
【摘要】 EvoNorms_PyTorch   https://github.com/lonePatient/EvoNorms_PyTorch 原版说精度提升了一个多点,但是内存占用比原来大了很多,也变慢了 import torchimport torch.nn as nnfrom torch.nn import initfrom torch.nn.parameter...

EvoNorms_PyTorch

 

https://github.com/lonePatient/EvoNorms_PyTorch

原版说精度提升了一个多点,但是内存占用比原来大了很多,也变慢了


  
  1. import torch
  2. import torch.nn as nn
  3. from torch.nn import init
  4. from torch.nn.parameter import Parameter
  5. def instance_std(x, eps=1e-5):
  6. N,C,H,W = x.size()
  7. x1 = x.reshape(N*C,-1)
  8. var = x1.var(dim=-1, keepdim=True)+eps
  9. return var.sqrt().reshape(N,C,1,1)
  10. def group_std(x, groups, eps = 1e-5):
  11. N, C, H, W = x.size()
  12. x1 = x.reshape(N,groups,-1)
  13. var = (x1.var(dim=-1, keepdim = True)+eps).reshape(N,groups,-1)
  14. return (x1 / var.sqrt()).reshape(N,C,H,W)
  15. class BatchNorm2dRelu(nn.Module):
  16. def __init__(self,in_channels):
  17. super(BatchNorm2dRelu,self).__init__()
  18. self.layer = nn.Sequential(
  19. nn.BatchNorm2d(in_channels),
  20. nn.ReLU(inplace=True))
  21. def forward(self, x):
  22. output = self.layer(x)
  23. return output
  24. class EvoNorm2dB0(nn.Module):
  25. def __init__(self,in_channels,nonlinear=True,momentum=0.9,eps = 1e-5):
  26. super(EvoNorm2dB0, self).__init__()
  27. self.nonlinear = nonlinear
  28. self.momentum = momentum
  29. self.eps = eps
  30. self.gamma = Parameter(torch.Tensor(1,in_channels,1,1))
  31. self.beta = Parameter(torch.Tensor(1,in_channels,1,1))
  32. if nonlinear:
  33. self.v = Parameter(torch.Tensor(1,in_channels,1,1))
  34. self.register_buffer('running_var', torch.ones(1, in_channels, 1, 1))
  35. self.reset_parameters()
  36. def reset_parameters(self):
  37. init.ones_(self.gamma)
  38. init.zeros_(self.beta)
  39. if self.nonlinear:
  40. init.ones_(self.v)
  41. def forward(self, x):
  42. N, C, H, W = x.size()
  43. if self.training:
  44. x1 = x.permute(1, 0, 2, 3).reshape(C, -1)
  45. var = x1.var(dim=1).reshape(1, C, 1, 1)
  46. self.running_var.copy_(self.momentum * self.running_var + (1 - self.momentum) * var)
  47. else:
  48. var = self.running_var
  49. if self.nonlinear:
  50. den = torch.max((var+self.eps).sqrt(), self.v * x + instance_std(x))
  51. return x / den * self.gamma + self.beta
  52. else:
  53. return x * self.gamma + self.beta
  54. class EvoNorm2dS0(nn.Module):
  55. def __init__(self,in_channels,groups=8,nonlinear=True):
  56. super(EvoNorm2dS0, self).__init__()
  57. self.nonlinear = nonlinear
  58. self.groups = groups
  59. self.gamma = Parameter(torch.Tensor(1,in_channels,1,1))
  60. self.beta = Parameter(torch.Tensor(1,in_channels,1,1))
  61. if nonlinear:
  62. self.v = Parameter(torch.Tensor(1,in_channels,1,1))
  63. self.reset_parameters()
  64. def reset_parameters(self):
  65. init.ones_(self.gamma)
  66. init.zeros_(self.beta)
  67. if self.nonlinear:
  68. init.ones_(self.v)
  69. def forward(self, x):
  70. if self.nonlinear:
  71. num = torch.sigmoid(self.v * x)
  72. std = group_std(x,self.groups)
  73. return num * std * self.gamma + self.beta
  74. else:
  75. return x * self.gamma + self.beta

 

文章来源: blog.csdn.net,作者:网奇,版权归原作者所有,如需转载,请联系作者。

原文链接:blog.csdn.net/jacke121/article/details/105434430

【版权声明】本文为华为云社区用户转载文章,如果您发现本社区中有涉嫌抄袭的内容,欢迎发送邮件进行举报,并提供相关证据,一经查实,本社区将立刻删除涉嫌侵权内容,举报邮箱: cloudbbs@huaweicloud.com
  • 点赞
  • 收藏
  • 关注作者

评论(0

0/1000
抱歉,系统识别当前为高风险访问,暂不支持该操作

全部回复

上滑加载中

设置昵称

在此一键设置昵称,即可参与社区互动!

*长度不超过10个汉字或20个英文字符,设置后3个月内不可修改。

*长度不超过10个汉字或20个英文字符,设置后3个月内不可修改。