Vision Transformer中的Patch概念解析
简单来说,Patch就是ViT将图像分割成的一个个小图像块。你可以把它想象成将一张完整的拼图打散成一个个小拼图块,然后模型再对这些小拼图块进行处理。
1. 核心概念:什么是Patch?
在传统的卷积神经网络中,我们使用滑动窗口的卷积核来提取图像的局部特征。而ViT的思路完全不同,它借鉴了Transformer在NLP领域的成功经验。
- 类比NLP中的Token(词元):在NLP中,一句话会被切分成一个个单词或子词(例如:“I love AI” -> [“I”, “love”, “AI”]),这些就是Token。每个Token会被转换成一个数值向量(词嵌入)。
- 图像中的Patch:在ViT中,一张图像被看作是这些“视觉词元”的集合。我们将一张图像(例如 224x224像素)规则地切割成N个固定大小的正方形块(例如 16x16像素),每个块就是一个Patch。
一个简单的计算示例:
一张 224x224x3 的图像,使用 16x16 的Patch大小进行分割:
- 每行有 224 / 16 = 14 个Patch
- 每列有 224 / 16 = 14 个Patch
- 总共有 14 * 14 = 196 个Patch。
这196个Patch就是模型要处理的“句子”,而每个Patch就是其中的一个“单词”。
2. 为什么需要Patch?—— 动机与原因
将图像切成Patch是ViT能够工作的前提,主要原因如下:
- 适配Transformer架构:Transformer的核心是Self-Attention机制,它需要一串序列作为输入。图像是二维的网格结构,通过将其展平为一维的Patch序列,我们成功地将图像数据“翻译”成了Transformer能理解的语言。
- 降低计算复杂度:原始的Self-Attention机制的计算复杂度是输入序列长度的平方。如果对每个像素(224x224=50,176个像素)做Attention,计算量将是天文数字。通过切割成Patch(如196个),序列长度被极大地缩短,使得训练成为可能。
- 捕获局部到全局的信息:每个16x16的Patch本身就包含了局部区域的信息(比如眼睛的一部分、羽毛的纹理)。通过Transformer层深度的增加,模型能够逐步将这些局部信息整合,最终理解全局的语义。
3. 技术实现:从Patch到模型输入
这个过程是ViT的精髓,可以分为以下几步:
步骤 1:图像分块
将输入图像 H x W x C 分割成 N 个 P x P x C 的Patches。
N = (H * W) / (P * P),这就是最终序列的长度。
步骤 2:Patch展平与线性投影
将每个二维的Patch(P x P x C)展平成一个一维向量(长度为 P * P * C)。
然后,通过一个可训练的线性投影层(全连接层) 将这个向量映射到一个固定的维度 D。这个 D 就是Transformer模型隐藏层的大小。
这个线性投影层的作用至关重要:
- 它相当于NLP中的词嵌入层,将每个“视觉单词”转换成一个
D维的向量表示。 - 这个输出的向量被称为 Patch Embedding。
步骤 3:添加位置编码
由于Transformer本身不具备感知序列顺序的能力(它是置换不变的),而图像中Patch的相对位置信息又极其重要。因此,我们需要为每个Patch Embedding加上一个位置编码,来告诉模型每个Patch在原始图像中的位置。
步骤 4:添加[CLS] Token
仿照BERT,ViT在序列的开头添加了一个可学习的特殊标记,称为 [class] token。它的Embedding向量会经过所有Transformer层,其最终的输出状态被用作整个图像的聚合表示,用于最终的分类任务。
最终,模型的输入是一个由 [CLS] token + N个Patch Embeddings] 组成的序列,每个向量都是D维。
总结与类比
| NLP (BERT) | ViT | 作用 |
|---|---|---|
| 单词/子词 | 图像块 | 输入的基本单元 |
| 词嵌入层 | 线性投影层 | 将基本单元映射为向量 |
| 位置编码 | 位置编码 | 提供序列顺序信息 |
[CLS] Token |
[CLS] Token |
聚合全局信息用于分类 |
进阶思考:Patch大小的影响
- 大Patch(如32x32):序列长度短,计算效率高,但每个Patch包含的像素多,更偏向于捕获宏观的、粗糙的特征。可能会损失细粒度信息。
- 小Patch(如8x8, 4x4):序列长度长,计算成本高,但每个Patch更精细,能捕获更细节的纹理。模型性能通常更好,但需要更强的计算资源。
后来的模型如Swin Transformer引入了分层设计和窗口注意力,部分原因就是为了克服小Patch带来的长序列问题。
- 点赞
- 收藏
- 关注作者
评论(0)