多模态原理--CLIP
1.概述
Vision Transformer 证明将文本、图像、音频等经过处理 Embedding 作为 Transformer 模型的输入,通过训练可以使模型同时具有理解并融合多种不同类型信息的能力。CLIP 则是建立在Vision Transformer的基础上,预测图像和文本的相似度。具体来说就是 CLIP 通过大量的图像和文本对的学习,通过图像和文本特征向量之间的余弦相似度来预测这一点。

2. 文本分词器
英文按照字符分词,使用的词表是 ascii 码,所以词表的大小是256。
def tokenizer(text, encode=True,max_seq_length=32):
if encode:
out = chr(2) + text + chr(3) # 添加 SOT token 和 EOT token
out = out + chr(0)*(max_seq_length - len(out)) # 添加 Padding 字符
out = torch.tensor([ord(c) for c in out]) # 对文本进行编码
mask = (out>0).to(torch.int)
# mask 为什么需要形成方阵?
mask = mask.expand(max_seq_length, max_seq_length)
else:
# 将input_ids解码为text文本
out = "".join([chr(x) for x in text[1:text.index(0)-1]])
mask = None
return out, mask
2.1 为什么需要 padding 填充?
为了提高计算效率,训练数据都是需要多条数据分为一批进行并行计算。为了支持并行计算,所有需要保证同一批数据的形状一致。
2.2 为什么需要 mask ?
为了确保被掩码的数据不参与计算。
2.3 mask 的形状为什么是方阵?
注意力机制中计算注意力权重过程中,被 padding 的向量不应该贡献权重。比如 q4 和 k4 是被padding的向量,所以应该标示为负无穷,在参加softmax计算的时候,不会贡献权重。如下图,参与 softmax 计算的输入是向量长度X向量长度的方阵,所以mask也是相同形状方阵。


3. 整理数据
# 整理数据
# 数据来自手写数字识别,需要将其整理为图文对。
# 将图像标签整理为文本,比如手写 0 的标签,就需要整理为文本 “An image of 0”
class HandWritingMNIST(Dataset):
def __init__(self, train=True, captions_map=None):
# 从网络加载数据,并保存
self.dataset = MNIST(root="./datasets", train=train, download=True, transform=T.ToTensor())
# 数据标签对应的文本信息
self.captions = captions_map
def __len__(self):
return len(self.dataset)
def __getitem__(self, i):
# 在训练集或者测试集取出第i张图片
img, target = self.dataset[i]
# 图片对应的文本,以及文本的掩码
cap, mask = tokenizer(self.captions[target])
return img, target, cap, mask
4. 位置编码
# 位置编码
class PositionalEncoding(nn.Module):
def __init__(self, d_model, max_seq_length):
super().__init__()
# 初始化编码矩阵 (max_len, d_model)
pe = torch.zeros(size=(max_seq_length, d_model))
# 当前词在序列中的位置 (max_len, 1)
pos = torch.arange(0, max_seq_length).unsqueeze(1)
# 表示公式中2i (d_model/2, )
_2i = torch.arange(0, d_model, 2)
# 计算10000**(2i/d_model) (d_model/2, )
div_term = torch.pow(10000, (_2i / d_model))
# 按奇偶数维度计算位置编码值 (max_len, d_model)
pe[:, 0::2] = torch.sin(pos / div_term)
if d_model % 2 == 1:
_2i1 = torch.arange(0, d_model-1, 2)
div_term = torch.pow(10000, (_2i1 / d_model))
pe[:, 1::2] = torch.cos(pos / div_term)
self.register_buffer("pe", pe)
def forward(self, x):
# 将位置编码添加到嵌入中
x = x + self.pe
return x
5. Encoder 模型
# 注意力头
class AttentionHead(nn.Module):
def __init__(self, d_model, head_size):
super().__init__()
self.head_size = head_size
self.query = nn.Linear(d_model, head_size)
self.key = nn.Linear(d_model, head_size)
self.value = nn.Linear(d_model, head_size)
def forward(self, x, mask):
# 计算Q, K, V
Q = self.query(x)
K = self.key(x)
V = self.value(x)
# Q和K的点积
attention = Q @ K.transpose(-2, -1)
# 缩放
attention = attention / (self.head_size ** 0.5)
if mask is not None:
attention = attention.masked_fill(mask == 0, float("-inf"))
attention = torch.softmax(attention, dim=-1)
attention = attention @ V
return attention
# 多注意头
class MultiHeadAttention(nn.Module):
def __init__(self, d_model, n_heads):
super().__init__()
self.head_size = d_model // n_heads
self.W_o = nn.Linear(d_model, d_model)
self.heads = nn.ModuleList([AttentionHead(d_model, self.head_size) for _ in range(n_heads)])
def forward(self, x):
# 拼接多个注意力头
out = torch.cat([head(x) for head in self.heads], dim=-1)
out = self.W_o(out)
return out
# 多注意头 + 全连接层 + 层归一化和残差链接
class TransformerEncoder(nn.Module):
def __init__(self, d_model, n_heads, r_mlp=4):
super().__init__()
self.d_model = d_model
self.n_heads = n_heads
# 层归一化
self.ln1 = nn.LayerNorm(d_model)
# 多头注意力
self.mha = MultiHeadAttention(d_model, n_heads)
# 层归一化
self.ln2 = nn.LayerNorm(d_model)
# MLP
self.mlp = nn.Sequential(
nn.Linear(d_model, d_model * r_mlp),
nn.GELU(),
nn.Linear(d_model * r_mlp, d_model))
def forward(self, x):
# 第一次层归一化之后的残差
out = x + self.mha(self.ln1(x))
# 第二次层归一化之后的残差
out = out + self.mlp(self.ln2(out))
return out
6. 构建文本编码器
对于文本编码器,使用常规Transformer Encoder 模型。在输出 Transformer 的结果之前,将文本特征映射到文本和图像联合向量空间中,用于使用点积比较相似性。为了后期点积比较的方便,在这里将文本嵌入向量归一化,模长为1。
# 文本 Encoder 模型,并将EOT向量映射到维度为 emb_dim 的向量上
class TextEncoder(nn.Module):
def __init__(self, vocab_size, width, max_seq_length, n_heads, n_layers, emb_dim):
super().__init__()
self.max_seq_length = max_seq_length
self.encoder_embedding = nn.Embedding(vocab_size, width)
self.positional_embedding = PositionalEncoding(width, max_seq_length)
self.encoder = nn.ModuleList([TransformerEncoder(width, n_heads) for _ in range(n_layers)])
# 可学习投影(projection)
self.projection = nn.Linear(width, emb_dim, bias=False)
def forward(self, text, mask):
# 文本嵌入
x = self.encoder_embedding(text)
# 位置嵌入
x = self.positional_embedding(x)
# Transformer编码器
for encoder_layer in self.encoder:
x = encoder_layer(x, mask=mask)
# 从EOT的嵌入抽取特征
x = x[
torch.arange(text.shape[0]), # 批次中数据的索引
torch.sub(torch.sum(mask[:, 0], dim=1), 1) # 取出掩码mask矩阵的第0行,加和再减1,就得到了EOT的索引
]
# 将文本特征嵌入到联合嵌入空间中(多模态嵌入空间)
# 文本编码器输出的张量的维度和图像编码器输出的张量的维度必须一致
if self.projection is not None:
x = self.projection(x)
# 除以向量的模长或者范式
x = x / torch.norm(x, dim=-1, keepdim=True)
return x
7. 构建图像编码器
# 图像 Encoder 模型,并将 cls_token 向量映射到维度为 emb_dim 的向量上
class ImageEncoder(nn.Module):
def __init__(self, width, img_size, patch_size, n_channels, n_layers, n_heads, emb_dim):
super().__init__()
assert img_size[0] % patch_size[0] == 0 \
and img_size[1] % patch_size[1] == 0, \
"img_size必须能被patch_size整除"
assert width % n_heads == 0, \
"width必须能被n_heads整除"
self.n_patches = (img_size[0] * img_size[1]) // (patch_size[0] * patch_size[1])
self.max_seq_length = self.n_patches + 1
self.linear_project = nn.Conv2d(n_channels, width, kernel_size=patch_size, stride=patch_size)
self.cls_token = nn.Parameter(torch.randn(1, 1, width))
self.positional_embedding = PositionalEncoding(width, self.max_seq_length)
self.encoder = nn.ModuleList([
TransformerEncoder(width, n_heads)
for _ in range(n_layers)
])
# 可学习的投影
self.projection = nn.Linear(width, emb_dim, bias=False)
def forward(self, x):
# 补丁嵌入
x = self.linear_project(x)
x = x.flatten(2).transpose(1, 2)
# 位置嵌入
x = torch.cat((self.cls_token.expand(x.size()[0], -1, -1), x), dim=1)
x = self.positional_embedding(x)
# Transformer编码器
for encoder_layer in self.encoder:
x = encoder_layer(x)
# 获取类别token
x = x[:, 0, :]
# 多模态嵌入
# 保证文本编码器的输出的维度和图像编码器的输出的维度相等
if self.projection is not None:
x = self.projection(x)
x = x / torch.norm(x, dim=-1, keepdim=True)
return x
8. CLIP模型
class CLIP(nn.Module):
def __init__(self, emb_dim, vit_width, img_size, patch_size, n_channels, vit_layers,
vit_heads, vocab_size, text_width, max_seq_length, text_heads, text_layers):
super().__init__()
self.image_encoder = ImageEncoder(vit_width, img_size, patch_size, n_channels, vit_layers, vit_heads, emb_dim)
self.text_encoder = TextEncoder(vocab_size, text_width, max_seq_length, text_heads, text_layers, emb_dim)
# 可学习温度
self.temperature = nn.Parameter(torch.ones([]) * np.log(1 / 0.07))
def forward(self, image, text, mask):
# Iₑ是图像嵌入,形状 [B, D=emb_dim]
I_e = self.image_encoder(image)
# Tₑ是文本嵌入,形状 [B, D=emb_dim]
T_e = self.text_encoder(text, mask=mask)
# 缩放逐点余弦相似度[n, n]
# 形状 I_e @ T_e^T : [B, D] @ [D, B] --> [B, B]
logits = (I_e @ T_e.transpose(-2, -1)) * torch.exp(self.temperature)
# 对称损失函数 labels形状为[B],值为 [0, 1, 2, ..., B-1]
labels = torch.arange(logits.shape[0]).to(device)
# 从文本 --> 图像方向,以文本嵌入 T₃ 为例子,
# 交叉熵损失的目标是让 T₃⋅I₃ 越大越好
loss_i = nn.functional.cross_entropy(logits.transpose(-2, -1), labels)
# 从图像 --> 文本方向,以图像嵌入 I₃ 为例子,
# 交叉熵损失的目标是让 I₃⋅T₃ 越大越好
loss_t = nn.functional.cross_entropy(logits, labels)
# 两个方向的损失求平均值
loss = (loss_i + loss_t) / 2
return loss
9. 模型训练
# 基础配置
ROOT_DIR = Path(__file__).parent.parent
device = 'cuda' if torch.cuda.is_available() else 'cpu'
log_dir = ROOT_DIR / 'logs'
# 超参数配置
emb_dim = 32 # 文本编码器和图像编码器输出的张量的维度
vit_width = 9 # 图像编码器的嵌入的宽度
img_size = (28, 28)
patch_size = (14, 14)
n_channels = 1
vit_layers = 3 # 图像编码器中编码器的层的数量
vit_heads = 3 # 图像编码器中注意力头的数量
vocab_size = 256 # 词汇表大小
text_width = 32 # 文本编码器的嵌入的宽度
max_seq_length = 32 # 最大序列长度
text_heads = 8 # 文本编码器中注意力头的数量
text_layers = 4 # 文本编码器中编码器的层数
lr = 1e-3 # 学习率
epochs = 10
batch_size = 128
# 图片 label 和文本对应关系
captions_dict = {
0: "An image of 0",
1: "An image of 1",
2: "An image of 2",
3: "An image of 3",
4: "An image of 4",
5: "An image of 5",
6: "An image of 6",
7: "An image of 7",
8: "An image of 8",
9: "An image of 9"
}
# 加载数据
train_set = HandWritingMNIST(train=True, captions_map=captions_dict)
test_set = HandWritingMNIST(train=False, captions_map=captions_dict)
# 数据分批
train_loader = DataLoader(train_set, shuffle=True, batch_size=batch_size)
test_loader = DataLoader(test_set, shuffle=False, batch_size=batch_size)
# 模型初始化
model = CLIP(emb_dim, vit_width, img_size, patch_size, n_channels, vit_layers, vit_heads,
vocab_size, text_width, max_seq_length, text_heads, text_layers).to(device)
# 优化器
optimizer = optim.Adam(model.parameters(), lr=lr)
best_loss = np.inf
# 开始训练
with SummaryWriter(log_dir=str(log_dir / time.strftime('%Y-%m-%d_%H-%M-%S'))) as writer:
for epoch in range(epochs):
for img, _, cap, mask in train_loader:
img, cap, mask = img.to(device), cap.to(device), mask.to(device)
loss = model(img, cap, mask)
optimizer.zero_grad()
loss.backward()
optimizer.step()
print(f"Epoch [{epoch + 1}/{epochs}], Batch Loss: {loss.item():.3f}")
# 保存模型
if loss.item() <= best_loss:
best_loss = loss.item()
torch.save(model.state_dict(), "./clip.pt")
print("模型已经保存...")
writer.add_scalar('loss', loss.item(), epoch + 1)


10. 模型验证
# 加载最好的模型
model = CLIP(emb_dim, vit_width, img_size, patch_size, n_channels, vit_layers, vit_heads,
vocab_size, text_width, max_seq_length, text_heads, text_layers).to(device)
model.load_state_dict(torch.load("./clip.pt", map_location=device))
correct, total = 0, 0
caps = captions_dict.values()
caps_list = []
mask_list = []
for cap in caps:
cap_en, mask = tokenizer(cap, max_seq_length=32)
caps_list.append(cap_en.unsqueeze(0))
mask_list.append(mask.unsqueeze(0))
caps_tensor = torch.cat(caps_list, dim=0).to(device=device)
mask_tensor = torch.cat(mask_list, dim=0).to(device=device)
with torch.no_grad():
for img, target, _, _ in test_loader:
img, target = img.to(device), target.to(device)
# 使用clip模型中的图像编码器对图像抽取特征i
image_features = model.image_encoder(img)
# 使用clip模型中的文本编码器对文本抽取特征
text_features = model.text_encoder(caps_tensor, mask=mask_tensor)
# 归一化
image_features /= image_features.norm(dim=-1, keepdim=True)
text_features /= text_features.norm(dim=-1, keepdim=True)
# I_e @ T_e^T
similarity = (100.0 * image_features @ text_features.T).softmax(dim=-1)
_, indices = torch.max(similarity, 1)
correct += (target == indices).sum()
total += target.size()[0]
print(f'\n预测准确率: {100 * correct // total} %')

11. 总结:文本或者图像的编码器模型都参考了 BERT 的单句分类任务的处理,可以实现文搜图的功能。具体步骤:
①将数据库中的所有图片使用clip的图像编码器进行抽取特征
②将抽取的图片特征存入向量数据库
③将输入的搜索文本通过clip的文本编码器抽取文本特征
④使用余弦相似度找出向量数据库中和文本特征最接近的几张图片
- 点赞
- 收藏
- 关注作者
评论(0)