[PyTorch] ShuffleNetV2 Model Migration -- Performance Tuning Notes for Ascend 910 Training
Main workflow:
1. Forward-pass troubleshooting notes
2. Full-network troubleshooting notes
3. Python-side optimization details
import torch
import torch.nn as nn

# Original channel_shuffle operation
def channel_shuffle(x, groups):
    # type: (torch.Tensor, int) -> torch.Tensor
    batchsize, num_channels, height, width = x.data.size()
    channels_per_group = num_channels // groups

    # reshape
    x = x.view(batchsize, groups,
               channels_per_group, height, width)

    x = torch.transpose(x, 1, 2).contiguous()

    # flatten
    x = x.view(batchsize, -1, height, width)

    return x


class InvertedResidual(nn.Module):
    def __init__(self, inp, oup, stride):
        super(InvertedResidual, self).__init__()

        if not (1 <= stride <= 3):
            raise ValueError('illegal stride value')
        self.stride = stride

        branch_features = oup // 2
        assert (self.stride != 1) or (inp == branch_features << 1)

        if self.stride > 1:
            self.branch1 = nn.Sequential(
                self.depthwise_conv(inp, inp, kernel_size=3, stride=self.stride, padding=1),
                nn.BatchNorm2d(inp),
                nn.Conv2d(inp, branch_features, kernel_size=1, stride=1, padding=0, bias=False),
                nn.BatchNorm2d(branch_features),
                nn.ReLU(inplace=True),
            )
        else:
            self.branch1 = nn.Sequential()

        self.branch2 = nn.Sequential(
            nn.Conv2d(inp if (self.stride > 1) else branch_features,
                      branch_features, kernel_size=1, stride=1, padding=0, bias=False),
            nn.BatchNorm2d(branch_features),
            nn.ReLU(inplace=True),
            self.depthwise_conv(branch_features, branch_features, kernel_size=3, stride=self.stride, padding=1),
            nn.BatchNorm2d(branch_features),
            nn.Conv2d(branch_features, branch_features, kernel_size=1, stride=1, padding=0, bias=False),
            nn.BatchNorm2d(branch_features),
            nn.ReLU(inplace=True),
        )

    @staticmethod
    def depthwise_conv(i, o, kernel_size, stride=1, padding=0, bias=False):
        return nn.Conv2d(i, o, kernel_size, stride, padding, bias=bias, groups=i)

    def forward(self, x):
        if self.stride == 1:
            x1, x2 = x.chunk(2, dim=1)
            out = torch.cat((x1, self.branch2(x2)), dim=1)
        else:
            out = torch.cat((self.branch1(x), self.branch2(x)), dim=1)

        out = channel_shuffle(out, 2)

        return out
import numpy as np

def channel_shuffle_index_select(x, groups=2):
    N, C, H, W = x.shape
    inp = C
    # channel_shuffle is a fixed-rule reordering along the C dimension,
    # so it can be expressed as a single simple permutation (index_select)
    group_len = inp // groups
    index = torch.from_numpy(np.array(list(range(inp))).reshape(groups, group_len).transpose(1, 0).flatten()).long()
    x = x.index_select(1, index)
    return x

# Compare the results of the two operations: they are semantically identical
x = torch.randn(2, 232, 14, 14)
for group in [2, 4, 8]:
    out1 = channel_shuffle(x, group)
    out2 = channel_shuffle_index_select(x, group)
    print((out1 - out2).sum())
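For intuition, the permutation and its inverse can be printed for a small channel count. The inverse permutation is what the backward pass needs, and for groups=2 it splits exactly into the even positions followed by the odd positions, which is where the bp_index1/bp_index2 buffers registered later come from. A minimal illustrative sketch (not part of the original tuning work, numpy only):

inp, groups = 8, 2
group_len = inp // groups

# Forward permutation used by channel_shuffle_index_select
fp_index = np.arange(inp).reshape(groups, group_len).transpose(1, 0).flatten()
print(fp_index)      # [0 4 1 5 2 6 3 7]

# Inverse permutation: needed to propagate gradients back through the index_select
bp_index = np.argsort(fp_index)
print(bp_index)      # [0 2 4 6 1 3 5 7]

# For groups=2 the inverse is simply range(0, inp, 2) followed by range(1, inp, 2),
# i.e. exactly the bp_index1 / bp_index2 used in the fused implementation below
assert list(bp_index) == list(range(0, inp, 2)) + list(range(1, inp, 2))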
# Corresponds to the case: out = channel_shuffle(torch.cat((self.branch1(x), self.branch2(x)), dim=1))
# Replace channel_shuffle with channel_shuffle_index_select
# Custom op: fuse channel_shuffle_index_select and cat, using compute-style operators to eliminate non-contiguous tensors
class IndexSelectFullImplementation(torch.autograd.Function):
    @staticmethod
    def forward(ctx, x1, x2, fp_index, bp_index1, bp_index2):
        # Force a stream synchronization; only needed to stabilize training
        stream = torch.npu.current_stream()
        stream.synchronize()

        # Register bp_index1 and bp_index2 on ctx so they are available in backward
        ctx.bp_index1 = bp_index1
        ctx.bp_index2 = bp_index2

        x = torch.cat([x1, x2], dim=1)

        # Replace channel_shuffle with index_select; this is the case where no chunk operator follows
        result = x.index_select(1, fp_index)

        return result

    @staticmethod
    def backward(ctx, grad_output):
        # Force a stream synchronization; only needed to stabilize training
        stream = torch.npu.current_stream()
        stream.synchronize()

        # index_select does not support the 5HD format; cast to NCHW to reduce extra transdata operations
        grad_output.data = grad_output.data.npu_format_cast(0)

        # Using the backward expression derived from the forward pass, a pair of index_select
        # calls computes the backward of both index_select and cat at once
        out1 = grad_output.index_select(1, ctx.bp_index1)
        out2 = grad_output.index_select(1, ctx.bp_index2)
        return out1, out2, None, None, None, None


class IndexSelectHalfImplementation(torch.autograd.Function):
    @staticmethod
    def forward(ctx, x1, x2, fp_index1, fp_index2, bp_index1, bp_index2):
        ctx.bp_index1 = bp_index1
        ctx.bp_index2 = bp_index2
        x = torch.cat([x1, x2], dim=1)

        # Replace channel_shuffle with index_select; this is the case where a chunk operator follows
        return x.index_select(1, fp_index1), x.index_select(1, fp_index2)

    @staticmethod
    def backward(ctx, grad_output1, grad_output2):
        grad_output = torch.cat([grad_output1, grad_output2], 1)

        out1 = grad_output.index_select(1, ctx.bp_index1)
        out2 = grad_output.index_select(1, ctx.bp_index2)
        return out1, out2, None, None, None, None
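Since IndexSelectHalfImplementation contains no NPU-specific calls, its semantics can be sanity-checked on CPU against the reference cat + channel_shuffle + chunk path, for both the forward outputs and the gradients. A hypothetical check, assuming torch/numpy are imported and channel_shuffle and IndexSelectHalfImplementation from above are in scope:

inp, groups = 8, 2
group_len = inp // groups
perm = np.arange(inp).reshape(groups, group_len).transpose(1, 0).flatten().tolist()
fp_index1 = torch.tensor(perm[:group_len], dtype=torch.long)
fp_index2 = torch.tensor(perm[group_len:], dtype=torch.long)
bp_index1 = torch.tensor(list(range(0, inp, 2)), dtype=torch.long)
bp_index2 = torch.tensor(list(range(1, inp, 2)), dtype=torch.long)

x1 = torch.randn(2, inp // 2, 4, 4, requires_grad=True)
x2 = torch.randn(2, inp // 2, 4, 4, requires_grad=True)

# Reference path: cat -> channel_shuffle -> chunk
ref1, ref2 = channel_shuffle(torch.cat([x1, x2], dim=1), groups).chunk(2, dim=1)
(ref1.sum() + 2 * ref2.sum()).backward()
ref_g1, ref_g2 = x1.grad.clone(), x2.grad.clone()
x1.grad, x2.grad = None, None

# Fused path: cat and shuffle expressed as two index_select calls
out1, out2 = IndexSelectHalfImplementation.apply(x1, x2, fp_index1, fp_index2, bp_index1, bp_index2)
(out1.sum() + 2 * out2.sum()).backward()

print(torch.allclose(ref1, out1), torch.allclose(ref2, out2))            # True True
print(torch.allclose(ref_g1, x1.grad), torch.allclose(ref_g2, x2.grad))  # True True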
class Channel_Shuffle(nn.Module):
    def __init__(self, inp, groups=2, split_shuffle=True):
        super(Channel_Shuffle, self).__init__()

        self.split_shuffle = split_shuffle
        self.group_len = inp // groups

        # Initialize the fp_index needed by channel_shuffle_index_select
        self.out = np.array(list(range(inp))).reshape(groups, self.group_len).transpose(1, 0).flatten().tolist()

        # Register the fp_index as module buffers so they are carried to the device by .to(device),
        # which avoids the h2d copy cost
        # Only the common groups=2 case is shown here; extend it yourself for other scenarios
        if self.split_shuffle:
            self.register_buffer('fp_index1', torch.tensor(self.out[:self.group_len], dtype=torch.int32))
            self.register_buffer('fp_index2', torch.tensor(self.out[self.group_len:], dtype=torch.int32))
        else:
            self.register_buffer('fp_index', torch.tensor(self.out, dtype=torch.int32))

        # Register the corresponding bp_index as module buffers for the same reason
        self.register_buffer('bp_index1', torch.tensor(list(range(0, inp, 2)), dtype=torch.int32))
        self.register_buffer('bp_index2', torch.tensor(list(range(1, inp, 2)), dtype=torch.int32))

    def forward(self, x1, x2):
        if self.split_shuffle:
            return IndexSelectHalfImplementation.apply(x1, x2, self.fp_index1, self.fp_index2, self.bp_index1,
                                                       self.bp_index2)
        else:
            return IndexSelectFullImplementation.apply(x1, x2, self.fp_index, self.bp_index1, self.bp_index2)
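Registering the index tensors as buffers means they follow the module when it is moved to the device and appear in the state_dict, so no per-iteration host-to-device copy is needed. A small illustration (the .npu() call assumes an Ascend environment with the torch_npu adapter; on CPU only the commented lines change):

cs = Channel_Shuffle(inp=232, groups=2, split_shuffle=True)
print(cs.fp_index1.device, cs.bp_index1.device)    # cpu cpu
print([name for name, _ in cs.named_buffers()])    # ['fp_index1', 'fp_index2', 'bp_index1', 'bp_index2']

# On an Ascend device, the buffers migrate together with the module:
# cs = cs.npu()
# print(cs.fp_index1.device)                       # npu:0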
class InvertedResidual(nn.Module):
    def __init__(self, inp, oup, stride, split_shuffle=True):
        super(InvertedResidual, self).__init__()

        if not (1 <= stride <= 3):
            raise ValueError('illegal stride value')
        self.stride = stride

        branch_features = oup // 2
        assert (self.stride != 1) or (inp == branch_features << 1)

        if self.stride > 1:
            self.branch1 = nn.Sequential(
                self.depthwise_conv(inp, inp, kernel_size=3, stride=self.stride, padding=1),
                nn.BatchNorm2d(inp),
                nn.Conv2d(inp, branch_features, kernel_size=1, stride=1, padding=0, bias=False),
                nn.BatchNorm2d(branch_features),
                nn.ReLU(inplace=True),
            )
        else:
            self.branch1 = nn.Sequential()

        self.branch2 = nn.Sequential(
            nn.Conv2d(inp if (self.stride > 1) else branch_features,
                      branch_features, kernel_size=1, stride=1, padding=0, bias=False),
            nn.BatchNorm2d(branch_features),
            nn.ReLU(inplace=True),
            self.depthwise_conv(branch_features, branch_features, kernel_size=3, stride=self.stride, padding=1),
            nn.BatchNorm2d(branch_features),
            nn.Conv2d(branch_features, branch_features, kernel_size=1, stride=1, padding=0, bias=False),
            nn.BatchNorm2d(branch_features),
            nn.ReLU(inplace=True),
        )

        if self.stride > 1:
            self.channel_shuffle = Channel_Shuffle(inp=branch_features + branch_features, groups=2,
                                                   split_shuffle=split_shuffle)
        else:
            self.channel_shuffle = Channel_Shuffle(inp=inp, groups=2, split_shuffle=split_shuffle)

    @staticmethod
    def depthwise_conv(i, o, kernel_size, stride=1, padding=0, bias=False):
        return nn.Conv2d(i, o, kernel_size, stride, padding, bias=bias, groups=i)

    def forward(self, x):
        # The cat and chunk operations are removed here and fused into self.channel_shuffle
        if self.stride == 1:
            x1, x2 = x
            x2 = self.branch2(x2)
        else:
            x1 = self.branch1(x)
            x2 = self.branch2(x)

        out = self.channel_shuffle(x1, x2)

        return out
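With the fused module, adjacent blocks exchange the (x1, x2) pair directly instead of re-concatenating and re-chunking the feature map. One plausible way to chain the modified blocks is sketched below (illustrative only; the channel numbers follow the ShuffleNetV2 1.0x stage2 configuration, and the classes defined above are assumed to be in scope). Since split_shuffle=True keeps the output split, it runs on CPU as-is:

down = InvertedResidual(24, 116, stride=2, split_shuffle=True)    # stride > 1: takes a single tensor
basic = InvertedResidual(116, 116, stride=1, split_shuffle=True)  # stride == 1: takes the (x1, x2) pair

x = torch.randn(1, 24, 56, 56)
x = down(x)                  # -> (x1, x2), each [1, 58, 28, 28]
x = basic(x)                 # -> (x1, x2), still split
out = torch.cat(x, dim=1)    # the last block of a stage could instead use split_shuffle=False to return one tensor
print(out.shape)             # torch.Size([1, 116, 28, 28])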