Ascend上的FlashAttention实现
1 FlashAttention
FlashAttention是一种优化Transformer模型计算效率和内存使用的技术。它通过减少存储访问开销(Memory Access Cost,MAC),而非降低FLOPS(浮点运算次数),来提升性能。
2 前述知识点
涉及到内存访问,肯定与计算的硬件架构有关系。
从GPU架构进行解析,参考如下博客:
大模型推理加速技术的学习路线是什么
首先,我们将探讨GPU架构,特别是其内存层次结构。我们确定了两个重要模式:计算限制(compute bound)和内存限制(memory bound),并讨论了大规模Transformer推理受内存限制的原因。大部分优化都基于Transformer推理受内存限制这一基本事实,例如只要我们提高FLOP利用率,就能提高效率。
2.1 GPU架构
GPU架构总体如下图所示:
基础部分:DRAM(动态随机存取存储器)、L2缓存和SM(流处理器单元)
- 与CPU对比
- SM类似于CPU核心,但具有更高级的并行性;
- L2缓存和DRAM类似于CPU的L2缓存和DRAM
- 在Flash Attention论文中,L2缓存被称为SRAM(静态随机存取存储器)
- A100 80G SXM
- 08个SM,DRAM容量为80GB,有40M L2缓存
SM内部包含什么?
- L1缓存:指令和数据
- 张量核心:进行矩阵乘法运算的地方。回想一下,神经网络计算基本上就是巨大批量的矩阵乘法。
GPU编程基础
在执行model.generate(prompt)时,我们进行以下操作:
- 内存访问:
- 从高带宽内存(HBM)加载模型权重 -> L2缓存 -> 传输到SM(流处理器单元)
- 计算:
- 在SM中执行矩阵乘法,SM请求张量核心执行计算
- A100:
- 108个SM,DRAM容量为80G,40M L2缓存
- bf16张量核心:每秒312万亿浮点运算(TFLOPS)
- DRAM内存带宽为2039GB/秒 = 2.039T/秒
- 如果模型很大,我们将其分割到多个GPU上,比如两个由NVLink连接的GPU
- NVLink 300GB/秒 = 0.3T/秒
- 我们大致观察了速度层次结构。尽管不能直接比较,但它们的数量级差异是我们需要优化的主要方面:
- 312T(SM计算) > 2.03T(DRAM内存访问) > 0.3T=300G(NVLink跨设备通信) > 60G(PCIe跨设备通信)
- 这意味着,如果我们希望速度更快,我们应该尽力:
- 充分利用SM
- 减少单个GPU的内存访问(因为它比计算慢得多),减少GPU之间的通信(因为它甚至比内存访问还要慢)。
计算限制与内存限制
如何确定我们是否充分利用了SM呢?我们通过以下方式检查是否计算或内存限制:
定义每字节GPU操作 = flop / 内存带宽
-
A100 = 312 / 2.039
-
定义计算强度 = 计算 / 内存访问
-
如果计算强度大,说明程序更会受到计算限制;如果计算强度较小,则更受内存限制。
-
增加批次大小会将行为从内存限制变为计算限制。
-
内核融合:减少了内存访问操作,因为我们将多个操作合并为一个操作。
2.2 Transformer推理
内存布局
正如我们所看到的,为了在bf16格式下运行一个13B模型,我们大约只有10GB的内存来存储kv缓存。这意味着:
- 不能使用太大型的批次(尽管我们希望使用更大的批次大小以提高效率)
- 也不能处理太长的序列,尽管我们确实希望能够处理长度为100k的序列。
3 FlashAttention的策略
FlashAttention的核心策略包括:
- Tiling(平铺/切分):将注意力矩阵分解成更小的子矩阵,分别计算,确保每个子矩阵的大小适合SRAM(静态随机存取存储器)的存储能力,从而减少对HBM(高带宽内存)的访问。
- Recomputation(重算):在反向传播时,不存储所有中间状态,而是在需要时重新计算,节省内存。
- 分块SoftMax:解决标准SoftMax在分块计算中的问题,确保整个Flash Attention的正确性。
- 优化显存交换:减少SRAM与HBM之间的数据交换,加速计算。
这些策略共同作用,使FlashAttention在保持计算精度的同时,显著提高计算速度和内存效率
4 Ascend 上的FlashAttention
昇腾异构计算架构CANN针对昇腾AI处理器的片上内存和缓存大小,以及数据搬运通路,基于Ascend C算子编程语言优化实现FlashAttention融合算子,充分利用片上缓存,提升Attention处理性能。根据实测,在一些典型场景中CANN的FlashAttention算子相比小算子取得了5倍以上的性能提升,开发者可直接调用相关算子API接口使能大模型极致性能优化。
- 点赞
- 收藏
- 关注作者
评论(0)