[PyTorch] ShuffleNetV2 Model Migration -- Performance Tuning Notes for the Ascend 910 Training Scenario
Main workflow
1. Forward-pass troubleshooting notes
2. Whole-network troubleshooting notes
3. Python-side optimization details
```python
# Original channel_shuffle operation
import torch
import torch.nn as nn


def channel_shuffle(x, groups):
    # type: (torch.Tensor, int) -> torch.Tensor
    batchsize, num_channels, height, width = x.data.size()
    channels_per_group = num_channels // groups

    # reshape
    x = x.view(batchsize, groups, channels_per_group, height, width)
    x = torch.transpose(x, 1, 2).contiguous()

    # flatten
    x = x.view(batchsize, -1, height, width)
    return x


class InvertedResidual(nn.Module):
    def __init__(self, inp, oup, stride):
        super(InvertedResidual, self).__init__()
        if not (1 <= stride <= 3):
            raise ValueError('illegal stride value')
        self.stride = stride
        branch_features = oup // 2
        assert (self.stride != 1) or (inp == branch_features << 1)

        if self.stride > 1:
            self.branch1 = nn.Sequential(
                self.depthwise_conv(inp, inp, kernel_size=3, stride=self.stride, padding=1),
                nn.BatchNorm2d(inp),
                nn.Conv2d(inp, branch_features, kernel_size=1, stride=1, padding=0, bias=False),
                nn.BatchNorm2d(branch_features),
                nn.ReLU(inplace=True),
            )
        else:
            self.branch1 = nn.Sequential()

        self.branch2 = nn.Sequential(
            nn.Conv2d(inp if (self.stride > 1) else branch_features, branch_features,
                      kernel_size=1, stride=1, padding=0, bias=False),
            nn.BatchNorm2d(branch_features),
            nn.ReLU(inplace=True),
            self.depthwise_conv(branch_features, branch_features, kernel_size=3, stride=self.stride, padding=1),
            nn.BatchNorm2d(branch_features),
            nn.Conv2d(branch_features, branch_features, kernel_size=1, stride=1, padding=0, bias=False),
            nn.BatchNorm2d(branch_features),
            nn.ReLU(inplace=True),
        )

    @staticmethod
    def depthwise_conv(i, o, kernel_size, stride=1, padding=0, bias=False):
        return nn.Conv2d(i, o, kernel_size, stride, padding, bias=bias, groups=i)

    def forward(self, x):
        if self.stride == 1:
            x1, x2 = x.chunk(2, dim=1)
            out = torch.cat((x1, self.branch2(x2)), dim=1)
        else:
            out = torch.cat((self.branch1(x), self.branch2(x)), dim=1)
        out = channel_shuffle(out, 2)
        return out
```
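To make the reordering concrete, here is a minimal sketch (the tensor shape and values are illustrative, not from the original article): filling each channel of a 1x4x1x1 tensor with its own index shows that `channel_shuffle` with groups=2 simply applies a fixed permutation along the C axis.

```python
import torch

# Minimal sketch (illustrative values): make each channel carry its own index
# so the channel permutation produced by channel_shuffle becomes visible.
x = torch.arange(4.0).view(1, 4, 1, 1)
y = channel_shuffle(x, 2)
print(y.view(-1))  # tensor([0., 2., 1., 3.]) -- a fixed reordering along C
```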
```python
import numpy as np


def channel_shuffle_index_select(x, groups=2):
    N, C, H, W = x.shape
    inp = C
    # channel_shuffle is just a fixed reordering of the C axis,
    # so it can be expressed as a single index_select
    group_len = inp // groups
    index = torch.from_numpy(
        np.array(list(range(inp))).reshape(groups, group_len).transpose(1, 0).flatten()
    ).long()
    x = x.index_select(1, index)
    return x


# Compare the two operations: the results show they are semantically equivalent
x = torch.randn(2, 232, 14, 14)
for group in [2, 4, 8]:
    out1 = channel_shuffle(x, group)
    out2 = channel_shuffle_index_select(x, group)
    print((out1 - out2).sum())
```
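The custom op defined next also needs the backward of this reordering. Because the shuffle is a pure permutation, its gradient is the inverse permutation, and for groups=2 that inverse is exactly `range(0, inp, 2)` concatenated with `range(1, inp, 2)`, which is what `bp_index1` / `bp_index2` hold below. A quick sanity check, with an illustrative channel count:

```python
import numpy as np

# Sanity check (illustrative inp value): bp_index must be the inverse of fp_index
# so that index_select with bp_index undoes the shuffle in the backward pass.
inp, groups = 8, 2
group_len = inp // groups
fp_index = np.arange(inp).reshape(groups, group_len).transpose(1, 0).flatten()
bp_index = np.concatenate([np.arange(0, inp, 2), np.arange(1, inp, 2)])
print(fp_index[bp_index])  # [0 1 2 3 4 5 6 7] -> composing the two gives the identity
```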
```python
# Handles the case out = channel_shuffle(torch.cat((self.branch1(x), self.branch2(x)), dim=1))
# Replace channel_shuffle with channel_shuffle_index_select
# Custom op: fuse channel_shuffle_index_select with cat, and use compute-type
# operators to eliminate non-contiguous tensors
class IndexSelectFullImplementation(torch.autograd.Function):
    @staticmethod
    def forward(ctx, x1, x2, fp_index, bp_index1, bp_index2):
        # Force a stream synchronization; only used to stabilize training
        stream = torch.npu.current_stream()
        stream.synchronize()

        # Register bp_index1 / bp_index2 on ctx so the backward pass can use them
        ctx.bp_index1 = bp_index1
        ctx.bp_index2 = bp_index2

        x = torch.cat([x1, x2], dim=1)

        # Use index_select instead of channel_shuffle; this path is for the case
        # where no chunk operator follows
        result = x.index_select(1, fp_index)
        return result

    @staticmethod
    def backward(ctx, grad_output):
        # Force a stream synchronization; only used to stabilize training
        stream = torch.npu.current_stream()
        stream.synchronize()

        # index_select does not support the 5HD format, so cast to NCHW to avoid extra transdata
        grad_output.data = grad_output.data.npu_format_cast(0)

        # Backward expression derived from the forward: index_select implements the
        # backward of both index_select and cat at the same time
        out1 = grad_output.index_select(1, ctx.bp_index1)
        out2 = grad_output.index_select(1, ctx.bp_index2)
        return out1, out2, None, None, None


class IndexSelectHalfImplementation(torch.autograd.Function):
    @staticmethod
    def forward(ctx, x1, x2, fp_index1, fp_index2, bp_index1, bp_index2):
        ctx.bp_index1 = bp_index1
        ctx.bp_index2 = bp_index2
        x = torch.cat([x1, x2], dim=1)

        # Use index_select instead of channel_shuffle; this path is for the case
        # where a chunk operator follows
        return x.index_select(1, fp_index1), x.index_select(1, fp_index2)

    @staticmethod
    def backward(ctx, grad_output1, grad_output2):
        grad_output = torch.cat([grad_output1, grad_output2], 1)

        out1 = grad_output.index_select(1, ctx.bp_index1)
        out2 = grad_output.index_select(1, ctx.bp_index2)
        return out1, out2, None, None, None, None


class Channel_Shuffle(nn.Module):
    def __init__(self, inp, groups=2, split_shuffle=True):
        super(Channel_Shuffle, self).__init__()
        self.split_shuffle = split_shuffle
        self.group_len = inp // groups

        # Initialize the fp_index needed by channel_shuffle_index_select
        self.out = np.array(list(range(inp))).reshape(groups, self.group_len).transpose(1, 0).flatten().tolist()

        # Register fp_index as module buffers so they are carried to the device
        # along with .to(device), saving h2d copy time
        # Only the common groups=2 case is shown here; extend it yourself for other cases
        if self.split_shuffle:
            self.register_buffer('fp_index1', torch.tensor(self.out[:self.group_len], dtype=torch.int32))
            self.register_buffer('fp_index2', torch.tensor(self.out[self.group_len:], dtype=torch.int32))
        else:
            self.register_buffer('fp_index', torch.tensor(self.out, dtype=torch.int32))

        # Likewise register the bp_index buffers so they are carried to the device
        # along with .to(device), saving h2d copy time
        self.register_buffer('bp_index1', torch.tensor(list(range(0, inp, 2)), dtype=torch.int32))
        self.register_buffer('bp_index2', torch.tensor(list(range(1, inp, 2)), dtype=torch.int32))

    def forward(self, x1, x2):
        if self.split_shuffle:
            return IndexSelectHalfImplementation.apply(x1, x2, self.fp_index1, self.fp_index2,
                                                       self.bp_index1, self.bp_index2)
        else:
            return IndexSelectFullImplementation.apply(x1, x2, self.fp_index,
                                                       self.bp_index1, self.bp_index2)


class InvertedResidual(nn.Module):
    def __init__(self, inp, oup, stride, split_shuffle=True):
        super(InvertedResidual, self).__init__()
        if not (1 <= stride <= 3):
            raise ValueError('illegal stride value')
        self.stride = stride
        branch_features = oup // 2
        assert (self.stride != 1) or (inp == branch_features << 1)

        if self.stride > 1:
            self.branch1 = nn.Sequential(
                self.depthwise_conv(inp, inp, kernel_size=3, stride=self.stride, padding=1),
                nn.BatchNorm2d(inp),
                nn.Conv2d(inp, branch_features, kernel_size=1, stride=1, padding=0, bias=False),
                nn.BatchNorm2d(branch_features),
                nn.ReLU(inplace=True),
            )
        else:
            self.branch1 = nn.Sequential()

        self.branch2 = nn.Sequential(
            nn.Conv2d(inp if (self.stride > 1) else branch_features, branch_features,
                      kernel_size=1, stride=1, padding=0, bias=False),
            nn.BatchNorm2d(branch_features),
            nn.ReLU(inplace=True),
            self.depthwise_conv(branch_features, branch_features, kernel_size=3, stride=self.stride, padding=1),
            nn.BatchNorm2d(branch_features),
            nn.Conv2d(branch_features, branch_features, kernel_size=1, stride=1, padding=0, bias=False),
            nn.BatchNorm2d(branch_features),
            nn.ReLU(inplace=True),
        )

        if self.stride > 1:
            self.channel_shuffle = Channel_Shuffle(inp=branch_features + branch_features, groups=2,
                                                   split_shuffle=split_shuffle)
        else:
            self.channel_shuffle = Channel_Shuffle(inp=inp, groups=2, split_shuffle=split_shuffle)

    @staticmethod
    def depthwise_conv(i, o, kernel_size, stride=1, padding=0, bias=False):
        return nn.Conv2d(i, o, kernel_size, stride, padding, bias=bias, groups=i)

    def forward(self, x):
        # The cat and chunk ops are removed here and fused into self.channel_shuffle
        if self.stride == 1:
            x1, x2 = x
            x2 = self.branch2(x2)
        else:
            x1 = self.branch1(x)
            x2 = self.branch2(x)
        out = self.channel_shuffle(x1, x2)
        return out
```
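A usage sketch of the rewritten block (shapes are illustrative; it assumes a PyTorch build whose index_select accepts the int32 index buffers registered above, as the Ascend build does, otherwise cast them to int64): with split_shuffle=True a stride-1 block consumes and produces a pre-chunked pair, so chunk and cat are only needed once at the boundary of a stage instead of inside every block.

```python
import torch

# Usage sketch (illustrative shapes): drive a stride-1 block with a pre-chunked pair.
block = InvertedResidual(inp=116, oup=116, stride=1, split_shuffle=True)
x = torch.randn(2, 116, 28, 28)
x1, x2 = x.chunk(2, dim=1)        # chunk once at the entry of the stage
y1, y2 = block((x1, x2))          # the block keeps the data as an (x1, x2) pair
out = torch.cat([y1, y2], dim=1)  # re-assemble only when a full tensor is needed
print(out.shape)                  # torch.Size([2, 116, 28, 28])
```

Note that the split_shuffle=False path goes through IndexSelectFullImplementation, whose forward and backward call torch.npu stream synchronization and npu_format_cast, so that path only runs on an NPU device.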