多模态原理-- DALL-E2 中的CLIP模型

举报
剑指南天 发表于 2026/06/07 16:57:59 2026/06/07
【摘要】 为了从文本创建扩散图像(文生图),我们将使用CLIP模型中的嵌入。从CLIP获得的文本嵌入用于调节先验模型,使其扩散相应的图像嵌入。然后,这些图像嵌入用于调节解码器模型,用来指导解码器生成对应的图像。

1.概述

去噪扩散概率模型(DDPM)是在正向扩散过程中将噪声添加到图像中,以便训练模型预测在反向扩散过程中应在特定时间步去除的噪声。在对图像进行去噪时,需要从纯噪声的图像开始,并在每个时间步迭代地去除模型预测的噪声,直到获得最终图像。为了让 DDPM 生成指定内容的图像,需要向模型输入图像嵌入(条件扩散模型)。在OpenAI 的DALL-E 2模型,输入的图片标题描述被先验模型(prior model)生成 CLIP 图像嵌入。CLIP图像嵌入将被解码器网络(U-Net 模型)作用于图像的生成。

2. CLIP模型

2.1 位置编码

import time
from pathlib import Path

import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
import torchvision.transforms as T
from attr import dataclass
from torch.utils.data import Dataset, DataLoader
from torch.utils.tensorboard import SummaryWriter
from torchvision.datasets.mnist import MNIST


# 时间步编码
class PositionalEncoding(nn.Module):
    def __init__(self, d_model, max_seq_length):
        super().__init__()
        # 初始化编码矩阵 (max_len, d_model)
        pe = torch.zeros(size=(max_seq_length, d_model))
        # 当前词在序列中的位置 (max_len, 1)
        pos = torch.arange(0, max_seq_length).unsqueeze(1)
        # 表示公式中2i (d_model/2, )
        _2i = torch.arange(0, d_model, 2)
        # 计算10000**(2i/d_model) (d_model/2, )
        div_term = torch.pow(10000, (_2i / d_model))
        # 按奇偶数维度计算位置编码值 (max_len, d_model)
        pe[:, 0::2] = torch.sin(pos / div_term)
        if d_model % 2 == 1:
            div_term = div_term[:-1]
        pe[:, 1::2] = torch.cos(pos / div_term)
        self.register_buffer("pe", pe)

    def forward(self, t):
        return t + self.pe

2.2 自注意头(非因果)

# 自注意力头(非因果)
class AttentionHead(nn.Module):
    # d_model 每个token的嵌入长度
    # head_size 每个头的嵌入长度或者向量长度
    def __init__(self, d_model, head_size):
        super().__init__()
        self.head_size = head_size
        self.query = nn.Linear(d_model, head_size)
        self.key = nn.Linear(d_model, head_size)
        self.value = nn.Linear(d_model, head_size)

    def forward(self, x, mask):
        # 计算Q, K, V
        Q = self.query(x)
        K = self.key(x)
        V = self.value(x)

        # QK的点积
        attention = Q @ K.transpose(-2, -1)

        # 缩放
        attention = attention / (self.head_size ** 0.5)
        # 损失计算只需要关注新生成的有效的部分
        if mask is not None:
            attention = attention.masked_fill(mask == 0, float("-inf"))
        score = torch.softmax(attention, dim=-1)
        attention = score @ V
        return attention

2.3 多注意力头

# 多注意头
class MultiHeadAttention(nn.Module):
    def __init__(self, d_model, n_heads):
        super().__init__()
        # 在标准的 Transformer 架构中,通常要求向量的维度 d_model是注意力头数的整数倍,
        # 但严格来说,这并非绝对的数学约束,而是由设计惯例和实现效率决定的。
        self.head_size = d_model // n_heads
        self.W_o = nn.Linear(d_model, d_model)
        self.heads = nn.ModuleList([AttentionHead(d_model, self.head_size) for _ in range(n_heads)])

    def forward(self, x, mask):
        # 拼接多个注意力头
        out = torch.cat([head(x, mask) for head in self.heads], dim=-1)
        out = self.W_o(out)
        return out

2.4 Transformer 的 Encoder结构

# 多注意头 + 全连接层 + 层归一化和残差链接
class TransformerEncoder(nn.Module):
    # d_model 每个token的嵌入长度
    # head_size 每个头的嵌入长度或者向量长度
    # r_mlp 多注意头在模型中的层数
    def __init__(self, d_model, n_heads, r_mlp=4):
        super().__init__()
        self.d_model = d_model
        self.n_heads = n_heads

        # 层归一化
        self.ln1 = nn.LayerNorm(d_model)

        # 多头注意力
        self.mha = MultiHeadAttention(d_model, n_heads)

        # 层归一化
        self.ln2 = nn.LayerNorm(d_model)

        # MLP
        self.mlp = nn.Sequential(
            nn.Linear(d_model, d_model * r_mlp),
            nn.GELU(),
            nn.Linear(d_model * r_mlp, d_model)
        )

    def forward(self, x, mask=None):
        # 第一次层归一化之后的残差
        out = x + self.mha(self.ln1(x), mask)
        # 第二次层归一化之后的残差
        out = out + self.mlp(self.ln2(out))
        return out

2.5 英文分词器

# 英文分词器
def tokenizer(text, encode=True, max_seq_length=32):
    if encode:
        out = chr(2) + text + chr(3)  # 添加 SOT token  EOT token
        out = out + chr(0) * (max_seq_length - len(out))  # 添加 Padding 字符
        out = torch.tensor([ord(c) for c in out])  # 对文本进行编码
        mask = (out > 0).to(torch.int)
        # 因为是作用于Q@k,所以是方阵
        mask = mask.expand(max_seq_length, max_seq_length)
    else:
        # input_ids解码为text文本
        out = "".join([chr(x) for x in text[1:text.index(0) - 1]])
        mask = None
    return out, mask

2.6 文本嵌入模型

# 文本 Encoder 模型,并将EOT向量映射到维度为 emb_dim 的向量上
class TextEncoder(nn.Module):
    def __init__(self, vocab_size, width, max_seq_length, n_heads, n_layers, emb_dim):
        super().__init__()
        self.max_seq_length = max_seq_length
        self.encoder_embedding = nn.Embedding(vocab_size, width)
        self.positional_embedding = PositionalEncoding(width, max_seq_length)
        self.encoder = nn.ModuleList([TransformerEncoder(width, n_heads) for _ in range(n_layers)])
        # 可学习投影(projection        self.projection = nn.Linear(width, emb_dim, bias=False)

    def forward(self, text, mask):
        # 文本嵌入
        x = self.encoder_embedding(text)
        # 位置嵌入
        x = self.positional_embedding(x)
        # Transformer编码器
        for encoder_layer in self.encoder:
            x = encoder_layer(x, mask=mask)
        # EOT的嵌入抽取特征
        x = x[
            torch.arange(text.shape[0]),  # 批次中数据的索引
            torch.sub(torch.sum(mask[:, 0], dim=1), 1)  # 取出掩码mask矩阵的第0行,加和再减1,就得到了EOT的索引
        ]
        # 将文本特征嵌入到联合嵌入空间中(多模态嵌入空间)
        # 文本编码器输出的张量的维度和图像编码器输出的张量的维度必须一致
        x = self.projection(x)
        # 除以向量的模长或者范式
        x = x / torch.norm(x, dim=-1, keepdim=True)
        return x

2.7 图像嵌入模型

# 图像 Encoder 模型,并将 cls_token 向量映射到维度为 emb_dim 的向量上
class ImageEncoder(nn.Module):
    def __init__(self, width, img_size, patch_size, n_channels, n_layers, n_heads, emb_dim):
        super().__init__()

        assert img_size[0] % patch_size[0] == 0 \
               and img_size[1] % patch_size[1] == 0, \
            "img_size必须能被patch_size整除"
        assert width % n_heads == 0, \
            "width必须能被n_heads整除"

        self.n_patches = (img_size[0] * img_size[1]) // (patch_size[0] * patch_size[1])
        self.max_seq_length = self.n_patches + 1
        self.linear_project = nn.Conv2d(n_channels, width, kernel_size=patch_size, stride=patch_size)
        self.cls_token = nn.Parameter(torch.randn(1, 1, width))
        self.positional_embedding = PositionalEncoding(width, self.max_seq_length)
        self.encoder = nn.ModuleList([
            TransformerEncoder(width, n_heads)
            for _ in range(n_layers)
        ])
        # 可学习的投影
        self.projection = nn.Linear(width, emb_dim, bias=False)

2.8 CLIP 模型

class CLIP(nn.Module):
    # emb_dim 文本嵌入维度
    # vit_width 图像嵌入的维度
    # img_size 原始图像的 hw 
    # patch_size 补丁的 hw
    # n_channels 原始图像的通道
    # vit_layers 图像编码器的多头注意力的层数
    # vit_heads 图像编码器的注意力的头数
    # vocab_size 文本编码器的词表
    # text_width 文本嵌入的维度
    # max_seq_length 文本序列的长度
    # text_heads 文本编码器的注意力头数
    # text_layers 文本编码器多头注意力的层数
    def __init__(self, emb_dim, vit_width, img_size, patch_size, n_channels, vit_layers,
                 vit_heads, vocab_size, text_width, max_seq_length, text_heads, text_layers):
        super().__init__()
        self.image_encoder = ImageEncoder(vit_width, img_size, patch_size, n_channels, vit_layers, vit_heads, emb_dim)
        self.text_encoder = TextEncoder(vocab_size, text_width, max_seq_length, text_heads, text_layers, emb_dim)
        # 可学习温度
        self.temperature = nn.Parameter(torch.ones([]) * np.log(1 / 0.07))

    def forward(self, image, text, mask):
        # Iₑ是图像嵌入,形状 [B, D=emb_dim]
        I_e = self.image_encoder(image)
        # Tₑ是文本嵌入,形状 [B, D=emb_dim]
        T_e = self.text_encoder(text, mask=mask)

        # 缩放逐点余弦相似度[n, n]
        # 形状 I_e @ T_e^T : [B, D] @ [D, B] --> [B, B]
        logits = (I_e @ T_e.transpose(-2, -1)) * torch.exp(self.temperature)

        # 对称损失函数 labels形状为[B],值为 [0, 1, 2, ..., B-1]
        labels = torch.arange(logits.shape[0]).to(device)
        # 从文本 --> 图像方向,以文本嵌入 T₃ 为例子,
        # 交叉熵损失的目标是让 T₃I₃ 越大越好
        loss_i = nn.functional.cross_entropy(logits.transpose(-2, -1), labels)
        # 从图像 --> 文本方向,以图像嵌入 I₃ 为例子,
        # 交叉熵损失的目标是让 I₃T₃ 越大越好
        loss_t = nn.functional.cross_entropy(logits, labels)
        # 两个方向的损失求平均值
        loss = (loss_i + loss_t) / 2
        return loss

3. CLIP模型的训练

# 整理数据
# 数据来自手写数字识别,需要将其整理为图文对。
# 将图像标签整理为文本,比如手写 0 的标签,就需要整理为文本 “An image of 0”
class HandWritingMNIST(Dataset):
    def __init__(self, train=True, captions_map=None):
        # 从网络加载数据,并保存
        self.dataset = MNIST(root="./datasets", train=train, download=True, transform=T.ToTensor())
        # 数据标签对应的文本信息
        self.captions = captions_map

    def __len__(self):
        return len(self.dataset)

    def __getitem__(self, i):
        # 在训练集或者测试集取出第i张图片
        img, target = self.dataset[i]
        # 图片对应的文本,以及文本的掩码
        cap, mask = tokenizer(self.captions[target])
        return img, target, cap, mask

# 配置类
@dataclass
class ClipConfig:
    # 超参数配置
    emb_dim = 32  # 文本编码器和图像编码器输出的张量的维度
    vit_width = 9  # 图像编码器的嵌入的宽度
    img_size = (28, 28)
    patch_size = (14, 14)
    n_channels = 1
    vit_layers = 3  # 图像编码器中编码器的层的数量
    vit_heads = 3  # 图像编码器中注意力头的数量
    vocab_size = 256  # 词汇表大小
    text_width = 32  # 文本编码器的嵌入的宽度
    max_seq_length = 32  # 最大序列长度
    text_heads = 8  # 文本编码器中注意力头的数量
    text_layers = 4  # 文本编码器中编码器的层数
    lr = 1e-3  # 学习率
    epochs = 12
    batch_size = 128
    # 图片 label 和文本对应关系
    captions_dict = {
        0: "An image of a t-shirt/top",
        1: "An image of trousers",
        2: "An image of a pullover",
        3: "An image of a dress",
        4: "An image of a coat",
        5: "An image of a sandal",
        6: "An image of a shirt",
        7: "An image of a sneaker",
        8: "An image of a bag",
        9: "An image of an ankle boot"
    }
    # 基础配置
    ROOT_DIR = Path(__file__).parent.parent
    log_dir = ROOT_DIR / 'clip' / 'logs'
    device = 'cuda' if torch.cuda.is_available() else 'cpu'


if __name__ == '__main__':
    # 加载数据
    train_set = HandWritingMNIST(train=True, captions_map=ClipConfig.captions_dict)
    test_set = HandWritingMNIST(train=False, captions_map=ClipConfig.captions_dict)

    # 数据分批
    train_loader = DataLoader(train_set, shuffle=True, batch_size=ClipConfig.batch_size)
    test_loader = DataLoader(test_set, shuffle=False, batch_size=ClipConfig.batch_size)

    # 模型初始化
    model = CLIP(ClipConfig.emb_dim, ClipConfig.vit_width, ClipConfig.img_size, ClipConfig.patch_size,
                 ClipConfig.n_channels, ClipConfig.vit_layers, ClipConfig.vit_heads, ClipConfig.vocab_size,
                 ClipConfig.text_width, ClipConfig.max_seq_length, ClipConfig.text_heads, ClipConfig.text_layers).to(
        ClipConfig.device)

    # 优化器
    optimizer = optim.Adam(model.parameters(), lr=ClipConfig.lr)

    best_loss = np.inf
    # 开始训练
    with SummaryWriter(log_dir=str(ClipConfig.log_dir / time.strftime('%Y-%m-%d_%H-%M-%S'))) as writer:
        for epoch in range(ClipConfig.epochs):
            for img, _, cap, mask in train_loader:
                img, cap, mask = img.to(ClipConfig.device), cap.to(ClipConfig.device), mask.to(ClipConfig.device)
                loss = model(img, cap, mask)
                optimizer.zero_grad()
                loss.backward()
                optimizer.step()

            print(f"Epoch [{epoch + 1}/{ClipConfig.epochs}], Batch Loss: {loss.item():.3f}")
            # 保存模型
            if loss.item() <= best_loss:
                best_loss = loss.item()
                torch.save(model.state_dict(), "./clip.pt")
                print("模型已经保存...")
            writer.add_scalar('loss', loss.item(), epoch + 1)

4. CLIP 的验证

# 加载最好的模型
model = CLIP(emb_dim, vit_width, img_size, patch_size, n_channels, vit_layers, vit_heads,
             vocab_size, text_width, max_seq_length, text_heads, text_layers).to(device)
model.load_state_dict(torch.load("./clip.pt", map_location=device))

correct, total = 0, 0
caps = captions_dict.values()
caps_list = []
mask_list = []
for cap in caps:
    cap_en, mask = tokenizer(cap, max_seq_length=32)
    caps_list.append(cap_en.unsqueeze(0))
    mask_list.append(mask.unsqueeze(0))
caps_tensor = torch.cat(caps_list, dim=0).to(device=device)
mask_tensor = torch.cat(mask_list, dim=0).to(device=device)

with torch.no_grad():
    for img, target, _, _ in test_loader:
        img, target = img.to(device), target.to(device)
        # 使用clip模型中的图像编码器对图像抽取特征i
        image_features = model.image_encoder(img)
        # 使用clip模型中的文本编码器对文本抽取特征
        text_features = model.text_encoder(caps_tensor, mask=mask_tensor)
        # 归一化
        image_features /= image_features.norm(dim=-1, keepdim=True)
        text_features /= text_features.norm(dim=-1, keepdim=True)
        # I_e @ T_e^T
        similarity = (100.0 * image_features @ text_features.T).softmax(dim=-1)
        _, indices = torch.max(similarity, 1)
        correct += (target == indices).sum()
        total += target.size()[0]
print(f'\n预测准确率: {100 * correct // total} %')

【版权声明】本文为华为云社区用户原创内容,未经允许不得转载,如需转载请自行联系原作者进行授权。如果您发现本社区中有涉嫌抄袭的内容,欢迎发送邮件进行举报,并提供相关证据,一经查实,本社区将立刻删除涉嫌侵权内容,举报邮箱: cloudbbs@huaweicloud.com
  • 点赞
  • 收藏
  • 关注作者

评论(0

0/1000
抱歉,系统识别当前为高风险访问,暂不支持该操作

全部回复

上滑加载中

设置昵称

在此一键设置昵称,即可参与社区互动!

*长度不超过10个汉字或20个英文字符,设置后3个月内不可修改。

*长度不超过10个汉字或20个英文字符,设置后3个月内不可修改。