PolarMask
【摘要】
import torch
import torch.nn as nn
import torchvision
def Conv3x3ReLU(in_channels,out_channels):
...
import torch
import torch.nn as nn
import torchvision
def Conv3x3ReLU(in_channels, out_channels):
    """Build a 3x3 convolution followed by an in-place ReLU6.

    The convolution uses stride 1 and padding 1, so the spatial size of the
    input is preserved. Note that despite the function name, the activation
    is ``nn.ReLU6`` (output clamped to [0, 6]), not a plain ReLU.

    Args:
        in_channels: Number of channels of the input feature map.
        out_channels: Number of channels produced by the convolution.

    Returns:
        An ``nn.Sequential`` holding the convolution and the activation,
        in that order.
    """
    conv = nn.Conv2d(in_channels, out_channels, kernel_size=3, stride=1, padding=1)
    activation = nn.ReLU6(inplace=True)
    return nn.Sequential(conv, activation)
def locLayer(in_channels, out_channels):
    """Build the localisation head: four conv+activation stages and a projection.

    The first four stages are ``Conv3x3ReLU`` blocks that keep the channel
    count at ``in_channels``; the final plain 3x3 convolution (stride 1,
    padding 1, no activation) maps to ``out_channels``.

    Args:
        in_channels: Channel count of the incoming feature map; also the
            width of the four intermediate stages.
        out_channels: Channel count of the head's output.

    Returns:
        An ``nn.Sequential`` with five children (four tower stages followed
        by the output convolution).
    """
    # Four identical-width stages, created in order.
    tower = [
        Conv3x3ReLU(in_channels=in_channels, out_channels=in_channels)
        for _ in range(4)
    ]
    # Final projection has no activation.
    projection = nn.Conv2d(
        in_channels=in_channels,
        out_channels=out_channels,
        kernel_size=3,
        stride=1,
        padding=1,
    )
    return nn.Sequential(*tower, projection)
def conf_centernessLayer(in_channels, out_channels):
    """Build the joint classification/centerness head.

    The head is a stack of four ``Conv3x3ReLU`` stages that keep
    ``in_channels`` channels, followed by one plain 3x3 convolution
    (stride 1, padding 1) producing ``out_channels`` channels.

    Args:
        in_channels: Channel count of the incoming feature map and of every
            intermediate stage.
        out_channels: Channel count of the head's output.

    Returns:
        An ``nn.Sequential`` with five children.
    """
    stages = []
    for _ in range(4):
        stages.append(Conv3x3ReLU(in_channels=in_channels, out_channels=in_channels))
    # Output convolution carries no activation.
    stages.append(
        nn.Conv2d(in_channels, out_channels, kernel_size=3, stride=1, padding=1)
    )
    return nn.Sequential(*stages)
class PolarMask(nn.Module):
    """PolarMask-style network: ResNet-50 backbone, FPN neck and per-level heads.

    Five pyramid levels (p3..p7) are built from the backbone stages c3..c5.
    Every level gets its own localisation head (36 output channels,
    presumably one distance per polar ray) and its own joint
    classification/centerness head (``num_classes + 1`` output channels:
    the class scores followed by a single centerness channel).

    Args:
        num_classes: Number of classification channels per location.
    """

    def __init__(self, num_classes=21):
        super(PolarMask, self).__init__()
        self.num_classes = num_classes

        # Backbone: reuse the torchvision ResNet-50 children.
        # children()[:5] = conv1, bn1, relu, maxpool, layer1; [5..7] = layer2..layer4.
        resnet = torchvision.models.resnet50()
        layers = list(resnet.children())
        self.layer1 = nn.Sequential(*layers[:5])
        self.layer2 = nn.Sequential(*layers[5])
        self.layer3 = nn.Sequential(*layers[6])
        self.layer4 = nn.Sequential(*layers[7])

        # 1x1 lateral convolutions bring c5/c4/c3 down to 256 channels.
        self.lateral5 = nn.Conv2d(in_channels=2048, out_channels=256, kernel_size=1)
        self.lateral4 = nn.Conv2d(in_channels=1024, out_channels=256, kernel_size=1)
        self.lateral3 = nn.Conv2d(in_channels=512, out_channels=256, kernel_size=1)

        # Top-down path: transposed convolutions (k=4, s=2, p=1) double H and W.
        self.upsample4 = nn.ConvTranspose2d(in_channels=256, out_channels=256, kernel_size=4, stride=2, padding=1)
        self.upsample3 = nn.ConvTranspose2d(in_channels=256, out_channels=256, kernel_size=4, stride=2, padding=1)

        # Extra coarse levels: stride-2 convolutions produce p6 and p7.
        self.downsample6 = nn.Conv2d(in_channels=256, out_channels=256, kernel_size=3, stride=2, padding=1)
        self.downsample5 = nn.Conv2d(in_channels=256, out_channels=256, kernel_size=3, stride=2, padding=1)

        # forward() splits the joint head output into [num_classes, 1] channels,
        # so the head must emit num_classes + 1 channels. Building it with only
        # num_classes channels makes the split raise a RuntimeError.
        conf_centerness_channels = self.num_classes + 1

        self.loc_layer3 = locLayer(in_channels=256, out_channels=36)
        self.conf_centerness_layer3 = conf_centernessLayer(in_channels=256, out_channels=conf_centerness_channels)

        self.loc_layer4 = locLayer(in_channels=256, out_channels=36)
        self.conf_centerness_layer4 = conf_centernessLayer(in_channels=256, out_channels=conf_centerness_channels)

        self.loc_layer5 = locLayer(in_channels=256, out_channels=36)
        self.conf_centerness_layer5 = conf_centernessLayer(in_channels=256, out_channels=conf_centerness_channels)

        self.loc_layer6 = locLayer(in_channels=256, out_channels=36)
        self.conf_centerness_layer6 = conf_centernessLayer(in_channels=256, out_channels=conf_centerness_channels)

        self.loc_layer7 = locLayer(in_channels=256, out_channels=36)
        self.conf_centerness_layer7 = conf_centernessLayer(in_channels=256, out_channels=conf_centerness_channels)

        self.init_params()

    def init_params(self):
        """Re-initialise every Conv2d and BatchNorm2d in the model, backbone included.

        Conv2d weights get Kaiming-normal init (fan_out, relu); BatchNorm2d
        gets weight 1 and bias 0. Conv2d biases are left untouched, and the
        ConvTranspose2d upsamplers are not matched by either check, so they
        keep their default initialisation.
        """
        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')
            elif isinstance(m, nn.BatchNorm2d):
                nn.init.constant_(m.weight, 1)
                nn.init.constant_(m.bias, 0)

    @staticmethod
    def _flatten(t):
        """Move channels last and flatten everything but the batch dimension."""
        return t.permute(0, 2, 3, 1).contiguous().view(t.size(0), -1)

    def forward(self, x):
        """Run the backbone, build p3..p7 and apply the per-level heads.

        Args:
            x: Image batch of shape (N, 3, H, W). The top-down additions need
                each backbone stage to be exactly twice the spatial size of
                the next one; H and W that are multiples of 32 (e.g.
                800x1024) satisfy this.

        Returns:
            Tuple ``(locs, confs, centernesses)``. Each tensor has shape
            (N, K) where K concatenates levels p3..p7 in order, each level
            flattened in (H, W, C) order with C = 36, ``num_classes`` and 1
            respectively.
        """
        # Bottom-up backbone features.
        x = self.layer1(x)
        c3 = self.layer2(x)
        c4 = self.layer3(c3)
        c5 = self.layer4(c4)

        # Top-down pathway with lateral connections.
        p5 = self.lateral5(c5)
        p4 = self.upsample4(p5) + self.lateral4(c4)
        p3 = self.upsample3(p4) + self.lateral3(c3)

        # Coarser levels derived from p5.
        p6 = self.downsample5(p5)
        p7 = self.downsample6(p6)

        levels = (
            (p3, self.loc_layer3, self.conf_centerness_layer3),
            (p4, self.loc_layer4, self.conf_centerness_layer4),
            (p5, self.loc_layer5, self.conf_centerness_layer5),
            (p6, self.loc_layer6, self.conf_centerness_layer6),
            (p7, self.loc_layer7, self.conf_centerness_layer7),
        )

        locs, confs, centernesses = [], [], []
        for feature, loc_head, conf_centerness_head in levels:
            loc = loc_head(feature)
            # Joint head output = class scores followed by one centerness channel.
            conf, centerness = conf_centerness_head(feature).split([self.num_classes, 1], dim=1)
            locs.append(self._flatten(loc))
            confs.append(self._flatten(conf))
            centernesses.append(self._flatten(centerness))

        out = (
            torch.cat(locs, dim=1),
            torch.cat(confs, dim=1),
            torch.cat(centernesses, dim=1),
        )
        return out
if __name__ == '__main__':
    # Smoke test: build the model, print its structure, push one random
    # image-sized batch through it and print the three output shapes.
    model = PolarMask()
    print(model)

    # Named `dummy_input` rather than `input` so the builtin is not shadowed.
    dummy_input = torch.randn(1, 3, 800, 1024)

    # Only shapes are inspected, so skip autograd bookkeeping for this pass.
    with torch.no_grad():
        out = model(dummy_input)

    # out is the (locs, confs, centernesses) tuple returned by forward().
    for tensor in out:
        print(tensor.shape)
- 1
- 2
- 3
- 4
- 5
- 6
- 7
- 8
- 9
- 10
- 11
- 12
- 13
- 14
- 15
- 16
- 17
- 18
- 19
- 20
- 21
- 22
- 23
- 24
- 25
- 26
- 27
- 28
- 29
- 30
- 31
- 32
- 33
- 34
- 35
- 36
- 37
- 38
- 39
- 40
- 41
- 42
- 43
- 44
- 45
- 46
- 47
- 48
- 49
- 50
- 51
- 52
- 53
- 54
- 55
- 56
- 57
- 58
- 59
- 60
- 61
- 62
- 63
- 64
- 65
- 66
- 67
- 68
- 69
- 70
- 71
- 72
- 73
- 74
- 75
- 76
- 77
- 78
- 79
- 80
- 81
- 82
- 83
- 84
- 85
- 86
- 87
- 88
- 89
- 90
- 91
- 92
- 93
- 94
- 95
- 96
- 97
- 98
- 99
- 100
- 101
- 102
- 103
- 104
- 105
- 106
- 107
- 108
- 109
- 110
- 111
- 112
- 113
- 114
- 115
- 116
- 117
- 118
- 119
- 120
- 121
- 122
- 123
- 124
- 125
- 126
- 127
- 128
- 129
- 130
- 131
- 132
- 133
- 134
- 135
- 136
- 137
- 138
文章来源: wanghao.blog.csdn.net,作者:AI浩,版权归原作者所有,如需转载,请联系作者。
原文链接:wanghao.blog.csdn.net/article/details/121607259
【版权声明】本文为华为云社区用户转载文章,如果您发现本社区中有涉嫌抄袭的内容,欢迎发送邮件进行举报,并提供相关证据,一经查实,本社区将立刻删除涉嫌侵权内容,举报邮箱:cloudbbs@huaweicloud.com
- 点赞
- 收藏
- 关注作者
评论(0)