Main workflow: 1. Forward-pass investigation record 2. Whole-network investigation record 3. Python-side optimization details
1. Forward-pass investigation record
Because the native implementation torch.transpose(x, 1, 2).contiguous() is inefficient on the NPU and has no effective drop-in replacement, channel_shuffle_index_select is used instead; with identical semantics, a compute-class operator replaces the framework-class operator and reduces the time spent.
ShuffleNetV2 contains a large number of chunk operations. In PyTorch, chunk is a framework-class operator that splits a tensor into several equal-length, non-contiguous tensors, and the non-contiguous-to-contiguous conversion is currently expensive, so compute-class operators are used to eliminate the non-contiguity.
When adapting an operator, the adaptation layer by default sets the output format to the input format, but concat does not support the 5HD format when the C axis is not a multiple of 16 and falls back to 4D. Since the concat is followed by a GatherV2 operator that also only supports 4D, the result is 5HD->4D->concat->5HD->4D->gatherv2->5HD. The concat output format is therefore changed to 4D when C is not a multiple of 16; after the optimization the chain becomes 5HD->4D->concat->gatherv2->5HD.
Set the weight initialization format to avoid repeated transdata during computation (a minimal sketch follows this list).
Fixed the output-format specification of the DWCONV weights, avoiding some unnecessary 5HD->4D conversions.
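As an illustration of the weight-initialization-format item above, here is a minimal sketch; the helper name cast_conv_weights_ is made up for this example, and casting convolution weights to FRACTAL_Z (assumed ACL format ID 4) once after the model has been moved to the NPU is an assumption for illustration, not the exact adaptation-layer change.

import torch.nn as nn

def cast_conv_weights_(model):
    # Hypothetical helper: one-off cast of conv weights to the private FRACTAL_Z
    # format (assumed ACL format ID 4) so the weights are not converted back and
    # forth (transdata) at every training step. Call after model.to('npu').
    for module in model.modules():
        if isinstance(module, nn.Conv2d):
            module.weight.data = module.weight.data.npu_format_cast(4)
    return model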
2. Whole-network investigation record
Replace framework-class operators with compute-class operators.
Keep the index information on the NPU with a registered buffer, eliminating the index.to('npu') operation.
Use compute-class operators to eliminate non-contiguity.
contiguous_with_gatherv2 performs the non-contiguous-to-contiguous conversion with the AI Core operator GatherV2 (an illustrative sketch follows this list), but this …
Adjust the batch size.
Adjust the batch size + contiguous_with_gatherv2.
The backward of the concat operator is chunk, which reintroduces non-contiguity, so a custom backward for concat is defined that replaces chunk with GatherV2, fusing cat + shuffle + chunk and eliminating the non-contiguity.
The ReluGrad operator has two inputs: grad_output (the backward input) and self (the forward output). In ShuffleNet a 4D + 5HD combination sometimes occurs, and FE's format alignment usually aligns to the format of the first tensor, resulting in (4D, 5HD)->(4D, 4D)->ReluGrad->4D->5HD. Since the forward output format is essentially the input format, and relu is normally used in the Conv+BN+Relu pattern, outputting 5HD is the better choice in this scenario. An npu_format_cast is therefore inserted manually: (4D, 5HD)->(5HD, 5HD)->ReluGrad->5HD.
IndexSelectFullImplementation performs two gatherv2 operations on a 5HD tensor, which would trigger two 5HD->4D conversions; by doing one 5HD->4D cast manually beforehand, no transdata is needed at gatherv2 time, saving one transdata operation.
Add mixed precision O1/O2 (a minimal usage sketch follows this list).
Because of the parameter check in the Axpy operator, whenever C is not divisible by 16 the tensors are converted via transdata to 4D for the Axpy computation at every parameter update, introducing a large number of transdata operators. A guard function was added that ends the check early when the shapes of Axpy's inputs match, avoiding the format conversion and improving efficiency.
Remove all stream-synchronization operations -- this easily causes non-convergence and was not adopted.
After switching to the GatherV2 operator optimized for the non-aligned case, overall performance reached the delivery target.
After switching to the GatherV3 operator further optimized for the ShuffleNetV2 scenario, overall performance improved even further.
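The contiguous_with_gatherv2 item above can be illustrated with the following sketch; the helper shown here is illustrative only, not the actual implementation. Instead of calling .contiguous() on a non-contiguous slice, the same elements are gathered with index_select, which maps to the AI Core GatherV2 operator on the NPU and directly produces a contiguous result.

import torch

def contiguous_with_gatherv2(x, dim, start, length):
    # Illustrative only: build the contiguous equivalent of x.narrow(dim, start, length)
    # with index_select (lowered to GatherV2 on the NPU) instead of narrow(...).contiguous()
    index = torch.arange(start, start + length, dtype=torch.int32, device=x.device)
    return x.index_select(dim, index)

# e.g. the two halves otherwise produced by x.chunk(2, dim=1):
# x1 = contiguous_with_gatherv2(x, 1, 0, x.size(1) // 2)
# x2 = contiguous_with_gatherv2(x, 1, x.size(1) // 2, x.size(1) // 2)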
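For the mixed-precision item above, a minimal sketch of the usual Apex amp calls is given below; it assumes the NPU-adapted Apex is installed, and the model, data, and loss-scale value are placeholders chosen for illustration.

import torch
import torch.nn as nn
from apex import amp

model = nn.Conv2d(3, 8, 3).to('npu')                      # placeholder model
optimizer = torch.optim.SGD(model.parameters(), lr=0.1)
criterion = nn.MSELoss()

# O1 patches torch functions to cast to FP16 where safe; O2 keeps an FP16 model
# with FP32 master weights. The loss_scale value here is only an example.
model, optimizer = amp.initialize(model, optimizer, opt_level='O1', loss_scale=128.0)

images = torch.randn(2, 3, 32, 32).to('npu')
target = torch.randn(2, 8, 30, 30).to('npu')
output = model(images)
loss = criterion(output, target)
optimizer.zero_grad()
with amp.scale_loss(loss, optimizer) as scaled_loss:
    scaled_loss.backward()
optimizer.step()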
3. Python-side optimization details
import numpy as np
import torch
import torch.nn as nn


# Original channel_shuffle operation
def channel_shuffle(x, groups):
    # type: (torch.Tensor, int) -> torch.Tensor
    batchsize, num_channels, height, width = x.data.size()
    channels_per_group = num_channels // groups
    # reshape
    x = x.view(batchsize, groups, channels_per_group, height, width)
    x = torch.transpose(x, 1, 2).contiguous()
    # flatten
    x = x.view(batchsize, -1, height, width)
    return x


class InvertedResidual(nn.Module):
def __init__(self, inp, oup, stride):
super(InvertedResidual, self).__init__()
if not (1 <= stride <= 3):
raise ValueError('illegal stride value')
self.stride = stride
branch_features = oup // 2
assert (self.stride != 1) or (inp == branch_features << 1)
if self.stride > 1:
self.branch1 = nn.Sequential(
self.depthwise_conv(inp, inp, kernel_size=3, stride=self.stride, padding=1),
nn.BatchNorm2d(inp),
nn.Conv2d(inp, branch_features, kernel_size=1, stride=1, padding=0, bias=False),
nn.BatchNorm2d(branch_features),
nn.ReLU(inplace=True),
)
else:
self.branch1 = nn.Sequential()
self.branch2 = nn.Sequential(
nn.Conv2d(inp if (self.stride > 1) else branch_features,
branch_features, kernel_size=1, stride=1, padding=0, bias=False),
nn.BatchNorm2d(branch_features),
nn.ReLU(inplace=True),
self.depthwise_conv(branch_features, branch_features, kernel_size=3, stride=self.stride, padding=1),
nn.BatchNorm2d(branch_features),
nn.Conv2d(branch_features, branch_features, kernel_size=1, stride=1, padding=0, bias=False),
nn.BatchNorm2d(branch_features),
nn.ReLU(inplace=True),
)
    @staticmethod
    def depthwise_conv(i, o, kernel_size, stride=1, padding=0, bias=False):
return nn.Conv2d(i, o, kernel_size, stride, padding, bias=bias, groups=i)
def forward(self, x):
if self.stride == 1:
x1, x2 = x.chunk(2, dim=1)
out = torch.cat((x1, self.branch2(x2)), dim=1)
else:
out = torch.cat((self.branch1(x), self.branch2(x)), dim=1)
out = channel_shuffle(out, 2)
return out
def channel_shuffle_index_select(x, groups=2):
N, C, H, W = x.shape
    inp = C
    # channel_shuffle is a fixed, rule-based reordering along the C axis,
    # so it can be expressed as a single permutation via index_select
group_len = inp // groups
index = torch.from_numpy(np.array(list(range(inp))).reshape(groups, group_len).transpose(1, 0).flatten()).long()
x = x.index_select(1, index)
    return x


# Compare the results of the two operations: the semantics are identical
x = torch.randn(2, 232, 14, 14)
for group in [2, 4, 8]:
out1 = channel_shuffle(x, group)
out2 = channel_shuffle_index_select(x, group)
print((out1 - out2).sum())
# Covers the case out = channel_shuffle(torch.cat((self.branch1(x), self.branch2(x)), dim=1))
# Use channel_shuffle_index_select in place of channel_shuffle
# Custom op: fuse channel_shuffle_index_select and cat, and use compute-class
# operators to eliminate the non-contiguity
class IndexSelectFullImplementation(torch.autograd.Function):
    @staticmethod
    def forward(ctx, x1, x2, fp_index, bp_index1, bp_index2):
        # Forced stream synchronization; only serves to stabilize training
stream = torch.npu.current_stream()
stream.synchronize()
        # Register bp_index1 and bp_index2 on ctx so they can be used in backward
ctx.bp_index1 = bp_index1
ctx.bp_index2 = bp_index2
x = torch.cat([x1, x2], dim=1)
        # Use index_select in place of channel_shuffle; this is the case where no chunk follows
result = x.index_select(1, fp_index)
return result
    @staticmethod
    def backward(ctx, grad_output):
        # Forced stream synchronization; only serves to stabilize training
stream = torch.npu.current_stream()
stream.synchronize()
        # index_select does not support the 5HD format, so cast to NCHW here to avoid extra transdata
grad_output.data = grad_output.data.npu_format_cast(0)
        # Backward expression derived from the forward pass: index_select alone
        # implements the backward of both index_select and cat
out1 = grad_output.index_select(1, ctx.bp_index1)
out2 = grad_output.index_select(1, ctx.bp_index2)
        return out1, out2, None, None, None, None


class IndexSelectHalfImplementation(torch.autograd.Function):
    @staticmethod
    def forward(ctx, x1, x2, fp_index1, fp_index2, bp_index1, bp_index2):
ctx.bp_index1 = bp_index1
ctx.bp_index2 = bp_index2
x = torch.cat([x1, x2], dim=1)
        # Use index_select in place of channel_shuffle; this is the case followed by a chunk
return x.index_select(1, fp_index1), x.index_select(1, fp_index2)
    @staticmethod
    def backward(ctx, grad_output1, grad_output2):
grad_output = torch.cat([grad_output1, grad_output2], 1)
out1 = grad_output.index_select(1, ctx.bp_index1)
out2 = grad_output.index_select(1, ctx.bp_index2)
        return out1, out2, None, None, None, None


class Channel_Shuffle(nn.Module):
def __init__(self, inp, groups=2, split_shuffle=True):
super(Channel_Shuffle, self).__init__()
self.split_shuffle = split_shuffle
        self.group_len = inp // groups
        # Initialize the fp_index used by channel_shuffle_index_select
        self.out = np.array(list(range(inp))).reshape(groups, self.group_len).transpose(1, 0).flatten().tolist()
        # Register the fp_index as module buffers as needed, so they are moved to the
        # device together with the module in to(device), saving h2d copy time
        # Only the common groups=2 case is shown here; extend as needed for other cases
if self.split_shuffle:
self.register_buffer('fp_index1', torch.tensor(self.out[:self.group_len], dtype=torch.int32))
self.register_buffer('fp_index2', torch.tensor(self.out[self.group_len:], dtype=torch.int32))
else:
self.register_buffer('fp_index', torch.tensor(self.out, dtype=torch.int32))
        # Likewise register the corresponding bp_index as module buffers, so they are
        # moved to the device in to(device), saving h2d copy time
self.register_buffer('bp_index1', torch.tensor(list(range(0, inp, 2)), dtype=torch.int32))
self.register_buffer('bp_index2', torch.tensor(list(range(1, inp, 2)), dtype=torch.int32))
def forward(self, x1, x2):
if self.split_shuffle:
return IndexSelectHalfImplementation.apply(x1, x2, self.fp_index1, self.fp_index2, self.bp_index1,
self.bp_index2)
else:
            return IndexSelectFullImplementation.apply(x1, x2, self.fp_index, self.bp_index1, self.bp_index2)


class InvertedResidual(nn.Module):
def __init__(self, inp, oup, stride, split_shuffle=True):
super(InvertedResidual, self).__init__()
if not (1 <= stride <= 3):
raise ValueError('illegal stride value')
self.stride = stride
branch_features = oup // 2
assert (self.stride != 1) or (inp == branch_features << 1)
if self.stride > 1:
self.branch1 = nn.Sequential(
self.depthwise_conv(inp, inp, kernel_size=3, stride=self.stride, padding=1),
nn.BatchNorm2d(inp),
nn.Conv2d(inp, branch_features, kernel_size=1, stride=1, padding=0, bias=False),
nn.BatchNorm2d(branch_features),
nn.ReLU(inplace=True),
)
else:
self.branch1 = nn.Sequential()
self.branch2 = nn.Sequential(
nn.Conv2d(inp if (self.stride > 1) else branch_features,
branch_features, kernel_size=1, stride=1, padding=0, bias=False),
nn.BatchNorm2d(branch_features),
nn.ReLU(inplace=True),
self.depthwise_conv(branch_features, branch_features, kernel_size=3, stride=self.stride, padding=1),
nn.BatchNorm2d(branch_features),
nn.Conv2d(branch_features, branch_features, kernel_size=1, stride=1, padding=0, bias=False),
nn.BatchNorm2d(branch_features),
nn.ReLU(inplace=True),
)
if self.stride > 1:
self.channel_shuffle = Channel_Shuffle(inp=branch_features + branch_features, groups=2,
split_shuffle=split_shuffle)
else:
self.channel_shuffle = Channel_Shuffle(inp=inp, groups=2, split_shuffle=split_shuffle)
    @staticmethod
    def depthwise_conv(i, o, kernel_size, stride=1, padding=0, bias=False):
return nn.Conv2d(i, o, kernel_size, stride, padding, bias=bias, groups=i)
def forward(self, x):
        # The concat and chunk operations are removed here; they are fused into self.channel_shuffle
if self.stride == 1:
x1, x2 = x
x2 = self.branch2(x2)
else:
x1 = self.branch1(x)
x2 = self.branch2(x)
out = self.channel_shuffle(x1, x2)
        return out
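To show how the split_shuffle flag chains the rewritten blocks, here is an illustrative usage sketch; the channel counts are just an example, and the code assumes an Ascend NPU environment since the custom autograd functions call torch.npu.current_stream().

# Blocks with split_shuffle=True hand an (x1, x2) tuple to the next block, so the
# concat/chunk pair between neighbouring blocks disappears; the last block of a
# stage uses split_shuffle=False and returns a single tensor again.
blocks = nn.Sequential(
    InvertedResidual(24, 116, stride=2, split_shuffle=True),    # downsampling block: tensor in, tuple out
    InvertedResidual(116, 116, stride=1, split_shuffle=True),   # stride-1 block: tuple in, tuple out
    InvertedResidual(116, 116, stride=1, split_shuffle=False),  # last block: tuple in, tensor out
).to('npu')

x = torch.randn(2, 24, 56, 56).to('npu')
out = blocks(x)   # shape: [2, 116, 28, 28]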