多模态原理--条件扩散模型

举报
剑指南天 发表于 2026/06/06 13:46:28 2026/06/06
【摘要】 通过去噪扩散概率模型(Denoising Diffusion Probabilistic Models,DDPM)可以生成图像,但是生成的图像具有随机性,无法预先指定生成的内容。条件扩散模型可以在去噪扩散模型的基础上,根据给定的条件生成指定的图像。

1.概述

通过去噪扩散概率模型(Denoising Diffusion Probabilistic Models,DDPM)可以生成图像,但是生成的图像具有随机性,无法预先指定生成的内容。条件扩散模型可以在去噪扩散模型的基础上,根据给定的条件生成指定的图像。

2. 在 U-Net 网络添加条件

DDPM 中 U-Net 神经网络的输入是添加了噪声的图片和时间步。条件扩散模型在此基础上额外输入一个条件(例如类别标签),用它来影响网络预测的噪声。

2. U-Net 神经网络

2.1 U-Net 神经网络的结构

import time
from pathlib import Path

import matplotlib.pyplot as plt
import torch
from PIL import Image
from torch import nn
from torch.optim import Adam
from torch.utils.data import Dataset, DataLoader
from torch.utils.tensorboard import SummaryWriter
from torchvision import transforms
from torchvision.datasets.mnist import MNIST
from tqdm import tqdm


# Sinusoidal positional encoding of a single time step
def _pos_encoding(time_idx, output_dim, device='cpu'):
    """Encode one time step as a sinusoidal vector of length ``output_dim``.

    Even positions hold ``sin(t / 10000^(2i / D))`` and odd positions hold the
    ``cos`` of the same angle, where ``2i`` runs over the even indices below
    ``D = output_dim``.

    Args:
        time_idx: The time step ``t`` (a number, or a tensor that broadcasts
            against a 1-D tensor).
        output_dim: Length ``D`` of the returned vector.
        device: Device on which the encoding is created.

    Returns:
        A 1-D tensor with ``output_dim`` elements.
    """
    encoding = torch.zeros(output_dim, device=device)
    even_idx = torch.arange(0, output_dim, step=2, device=device)
    angles = time_idx / torch.pow(10000, even_idx / output_dim)

    # There are ceil(D / 2) even slots but only floor(D / 2) odd slots, so the
    # cosine half drops the last angle when D is odd.
    num_odd_slots = output_dim // 2
    encoding[0::2] = torch.sin(angles)
    encoding[1::2] = torch.cos(angles[..., :num_odd_slots])
    return encoding


# Batched positional encoding
def pos_encoding(time_steps, output_dim, device='cpu'):
    """Encode a batch of time steps as sinusoidal vectors in one vectorized pass.

    For every time step ``t`` the even positions of its row hold
    ``sin(t / 10000^(2i / D))`` and the odd positions hold the ``cos`` of the
    same angle, with ``D = output_dim``. The frequency table is computed once
    for the whole batch instead of once per sample.

    Args:
        time_steps: Batch of ``B`` time steps. May be a 1-D sequence/tensor of
            shape ``(B,)`` or a column of shape ``(B, 1)``.
        output_dim: Length ``D`` of each encoding vector.
        device: Device on which the result is created.

    Returns:
        A float tensor of shape ``(B, output_dim)``.
    """
    # One column of steps so that it broadcasts against the frequency row.
    steps = torch.as_tensor(time_steps, dtype=torch.float32, device=device)
    steps = steps.reshape(-1, 1)

    even_idx = torch.arange(0, output_dim, step=2, device=device)
    angles = steps / torch.pow(10000, even_idx / output_dim)  # (B, ceil(D / 2))

    v = torch.zeros(steps.size(0), output_dim, device=device)
    v[:, 0::2] = torch.sin(angles)
    # Only floor(D / 2) odd slots exist, so drop the last angle when D is odd.
    v[:, 1::2] = torch.cos(angles[:, :output_dim // 2])
    return v

# Convolution block conditioned on an embedding vector
class ConvBlock(nn.Module):
    """Two Conv-BatchNorm-ReLU stages applied to the input plus a projected embedding.

    The embedding ``v`` is passed through a small MLP that outputs one value
    per input channel; that value is added to every spatial position of the
    matching channel of ``x`` before the convolutions run.

    Args:
        in_ch: Number of channels of the input feature map.
        out_ch: Number of channels produced by the block.
        time_embed_dim: Length of the embedding vector ``v``.
    """

    def __init__(self, in_ch, out_ch, time_embed_dim):
        super().__init__()
        # Two 3x3 stages; only the first one changes the channel count.
        stages = []
        for stage_in in (in_ch, out_ch):
            stages.extend([
                nn.Conv2d(stage_in, out_ch, 3, padding=1),
                nn.BatchNorm2d(out_ch),
                nn.ReLU(),
            ])
        self.convs = nn.Sequential(*stages)
        # Project the embedding to one scalar per input channel so that it
        # can be broadcast-added to the feature map.
        self.mlp = nn.Sequential(
            nn.Linear(time_embed_dim, in_ch),
            nn.ReLU(),
            nn.Linear(in_ch, in_ch),
        )

    def forward(self, x, v):
        """Return the block output for feature map ``x`` (N, C, H, W) and embedding ``v``."""
        batch, channels, _height, _width = x.shape
        # (N, C) -> (N, C, 1, 1) so the addition broadcasts over H and W.
        channel_shift = self.mlp(v).view(batch, channels, 1, 1)
        return self.convs(x + channel_shift)


# Conditional U-Net
class UNetCondition(nn.Module):
    """Small U-Net that predicts noise from an image, a time step and an optional label.

    The time steps are turned into sinusoidal encodings; when labels are
    given, a learned label embedding of the same length is added to that
    encoding. The combined vector is fed to every ``ConvBlock``.

    Args:
        in_ch: Number of channels of the input (and output) image.
        time_embed_dim: Length of the time/label embedding vector.
        num_labels: Number of distinct labels. ``None`` builds an
            unconditional model without a label embedding table.
    """

    def __init__(self, in_ch=3, time_embed_dim=100, num_labels=None):
        super().__init__()
        self.time_embed_dim = time_embed_dim
        self.num_labels = num_labels
        self.down1 = ConvBlock(in_ch, 64, time_embed_dim)
        self.down2 = ConvBlock(64, 128, time_embed_dim)
        self.bot1 = ConvBlock(128, 256, time_embed_dim)
        # The decoder blocks take the upsampled features concatenated with
        # the skip connection, hence the summed channel counts.
        self.up2 = ConvBlock(128 + 256, 128, time_embed_dim)
        self.up1 = ConvBlock(128 + 64, 64, time_embed_dim)
        self.out = nn.Conv2d(64, in_ch, 1)

        self.max_pool = nn.MaxPool2d(2)
        self.up_sample = nn.Upsample(scale_factor=2, mode='bilinear')
        if self.num_labels is not None:
            self.label_emb = nn.Embedding(num_labels, time_embed_dim)

    def forward(self, x, time_steps, labels=None):
        """Run the network.

        Args:
            x: Input batch of shape (B, C, H, W). The image is pooled by 2
                twice and upsampled by 2 twice, so H and W need to be
                divisible by 4 for the skip concatenations to line up.
            time_steps: One time step per batch element, as accepted by
                ``pos_encoding``.
            labels: Optional label indices for ``nn.Embedding``, one per
                batch element. ``None`` runs the model unconditionally.

        Returns:
            A tensor with the same shape as ``x``.

        Raises:
            ValueError: If ``labels`` is given but the model was built with
                ``num_labels=None`` and therefore has no label embedding.
        """
        v = pos_encoding(time_steps, self.time_embed_dim, x.device)

        if labels is not None:
            if self.num_labels is None:
                raise ValueError(
                    "labels were provided, but this model was created with "
                    "num_labels=None and has no label embedding"
                )
            v = v + self.label_emb(labels)

        # x: (B, C, H, W)
        x1 = self.down1(x, v)  # (B, 64, H, W)
        x = self.max_pool(x1)  # (B, 64, H/2, W/2)
        x2 = self.down2(x, v)  # (B, 128, H/2, W/2)
        x = self.max_pool(x2)  # (B, 128, H/4, W/4)
        x = self.bot1(x, v)  # (B, 256, H/4, W/4)
        x = self.up_sample(x)  # (B, 256, H/2, W/2)
        x = torch.cat([x, x2], dim=1)  # (B, 256 + 128, H/2, W/2)
        x = self.up2(x, v)  # (B, 128, H/2, W/2)
        x = self.up_sample(x)  # (B, 128, H, W)
        x = torch.cat([x, x1], dim=1)  # (B, 128 + 64, H, W)
        x = self.up1(x, v)  # (B, 64, H, W)
        x = self.out(x)  # (B, C, H, W)
        return x

2.2 模型训练

class ImagesDatasets(Dataset):
    """MNIST training images, each returned with freshly sampled forward-diffusion noise.

    The constructor also builds the linear beta schedule and its derived
    ``alphas`` / ``alpha_bars`` tensors. It reads the module-level names
    ``num_timesteps`` and ``device``, so both must be defined before an
    instance is created.

    Args:
        img_size: Side length to which every image is resized.
    """

    def __init__(self, img_size):
        pipeline = [
            transforms.Resize((img_size, img_size)),  # square resize
            transforms.ToTensor(),  # to tensor, values in [0, 1]
            transforms.Lambda(lambda t: (t * 2) - 1),  # rescale to [-1, 1]
        ]
        self.preprocess = transforms.Compose(pipeline)
        self.dataset = MNIST(root="./datasets", train=True, download=True, transform=self.preprocess)

        # Variance schedule [beta_1, beta_2, ..., beta_T] and its products.
        schedule = torch.linspace(0.0001, 0.02, num_timesteps)
        self.betas = schedule.to(device=device)
        self.alphas = 1 - self.betas
        self.alpha_bars = torch.cumprod(self.alphas, dim=0)

    def __len__(self):
        """Return the number of images in the wrapped MNIST split."""
        return len(self.dataset)

    def __getitem__(self, item):
        """Return ``(x_t, target, noise, T)`` for image number ``item``.

        ``T`` is drawn uniformly from 1..num_timesteps (shape ``(1,)``),
        ``noise`` is standard normal with the image's shape, and ``x_t`` is
        the image diffused to step ``T`` with that noise.
        """
        clean, target = self.dataset[item]
        clean = clean.to(device=device)
        T = torch.randint(1, num_timesteps + 1, (1,)).to(device=device)
        # The schedule tensors are 0-based while T is 1-based.
        signal_level = self.alpha_bars[T - 1]
        noise = torch.randn_like(clean).to(device=device)
        x_t = signal_level.sqrt() * clean + (1 - signal_level).sqrt() * noise
        return x_t, target, noise, T


# Basic configuration
ROOT_DIR = Path(__file__).parent.parent
device = 'cuda' if torch.cuda.is_available() else 'cpu'
log_dir = ROOT_DIR / 'logs' / 'DDPM'

# Hyperparameters
img_size = 32
batch_size = 128
num_timesteps = 1000
epochs = 50
lr = 5e-4

# Data, model, optimizer and loss
num_labels = 10
images_datasets = ImagesDatasets(img_size)
data_loader = DataLoader(images_datasets, batch_size=batch_size, shuffle=True)
model = UNetCondition(in_ch=1, num_labels=num_labels).to(device)
optimizer = Adam(model.parameters(), lr=lr)
loss_f = nn.MSELoss()

# Each run logs into its own time-stamped sub-directory.
with SummaryWriter(log_dir=str(log_dir / time.strftime('%Y-%m-%d_%H-%M-%S'))) as writer:
    for epoch in range(epochs):
        batch_losses = []
        for x_t, labels, noise, t in tqdm(data_loader, "开始训练: "):
            x_t = x_t.to(device=device)
            labels = labels.to(device=device)
            noise = noise.to(device=device)
            t = t.to(device=device)

            # The model is trained to predict the noise that was mixed into x_t.
            optimizer.zero_grad()
            predicted_noise = model(x_t, t, labels)
            loss = loss_f(predicted_noise, noise)
            loss.backward()
            optimizer.step()
            batch_losses.append(loss.item())

        # Mean batch loss of this epoch (0.0 if the loader yielded nothing).
        loss_avg = sum(batch_losses) / len(batch_losses) if batch_losses else 0.0
        writer.add_scalar('loss', loss_avg, epoch + 1)
        print(f'Epoch {epoch + 1} | Loss: {loss_avg}')

3. 条件扩散去噪模型

# Convert a normalized tensor back into a PIL image
def reverse_to_img(x):
    """Map a tensor with values in [-1, 1] to an 8-bit PIL image.

    Values are shifted and scaled to [0, 255], clamped to that range, cast to
    ``uint8`` and moved to the CPU before the conversion.

    Args:
        x: Image tensor in a layout accepted by ``transforms.ToPILImage``.

    Returns:
        The PIL image produced by ``transforms.ToPILImage``.
    """
    pixels = ((x + 1) / 2 * 255).clamp(0, 255)
    as_bytes = pixels.to(torch.uint8).cpu()
    return transforms.ToPILImage()(as_bytes)


# Show images on a rows x cols grid (2 x 10 = 20 images by default)
def show_images(images, rows=2, cols=10):
    """Display up to ``rows * cols`` images on a grid, filled row by row.

    If fewer images than grid cells are supplied, the remaining cells are
    simply left empty instead of raising ``IndexError``; extra images beyond
    the grid capacity are ignored.

    Args:
        images: Indexable collection of images accepted by ``imshow``.
        rows: Number of grid rows.
        cols: Number of grid columns.
    """
    fig = plt.figure(figsize=(cols, rows))
    shown = min(len(images), rows * cols)
    for i in range(shown):
        # Subplot positions are 1-based.
        ax = fig.add_subplot(rows, cols, i + 1)
        ax.imshow(images[i], cmap='gray')
        ax.axis('off')
    plt.show()
# Reuse the noise schedule that the training dataset built.
alphas = images_datasets.alphas
alpha_bars = images_datasets.alpha_bars

# label -> list of intermediate PIL images collected while denoising
images = {}
# One sample per class label.
labels = torch.arange(0, num_labels).to(device=device)
# Draw ONE pure-noise image x_T with the training image shape and share it
# across all labels, so the samples differ only through the condition.
image_noise_pure = torch.randn((1, img_size, img_size)).to(device=device)
x = torch.stack([image_noise_pure] * num_labels)

batch_size = num_labels
# Inference mode only needs to be selected once, not on every step.
model.eval()
# for t = T, T-1, ..., 1
for i in tqdm(range(num_timesteps, 0, -1)):
    t = torch.tensor([i] * batch_size, dtype=torch.long).to(device=device)
    # One denoising step, x_t --> x_{t-1}
    t_idx = t - 1  # alphas[0] is for t=1
    alpha = alphas[t_idx]
    alpha_bar = alpha_bars[t_idx]
    # alpha_bar_0 is 1 by definition. Indexing with t_idx - 1 == -1 would
    # silently wrap around to alpha_bar_T, so handle t == 1 explicitly.
    alpha_bar_prev = torch.where(
        t_idx > 0,
        alpha_bars[(t_idx - 1).clamp(min=0)],
        torch.ones_like(alpha_bar),
    )

    # Reshape the per-sample scalars so they broadcast over (C, H, W).
    N = alpha.size(0)
    alpha = alpha.view(N, 1, 1, 1)
    alpha_bar = alpha_bar.view(N, 1, 1, 1)
    alpha_bar_prev = alpha_bar_prev.view(N, 1, 1, 1)

    with torch.no_grad():
        eps = model(x, t, labels)
    noise = torch.randn_like(x).to(device=device)
    noise[t == 1] = 0  # no noise at t=1
    mu = (x - ((1 - alpha) / torch.sqrt(1 - alpha_bar)) * eps) / torch.sqrt(alpha)
    std = torch.sqrt((1 - alpha) * (1 - alpha_bar_prev) / (1 - alpha_bar))
    # x_{t-1}
    x = mu + noise * std
    # Keep a snapshot every 50 steps (t = 951, 901, ..., 51, 1).
    if (i - 1) % 50 == 0:
        for n in labels.tolist():
            images.setdefault(n, []).append(reverse_to_img(x[n]))

# One figure per label showing its denoising trajectory.
for n in labels.tolist():
    show_images(images[n])







【版权声明】本文为华为云社区用户原创内容,未经允许不得转载,如需转载请自行联系原作者进行授权。如果您发现本社区中有涉嫌抄袭的内容,欢迎发送邮件进行举报,并提供相关证据,一经查实,本社区将立刻删除涉嫌侵权内容,举报邮箱: cloudbbs@huaweicloud.com
  • 点赞
  • 收藏
  • 关注作者

评论(0

0/1000
抱歉,系统识别当前为高风险访问,暂不支持该操作

全部回复

上滑加载中

设置昵称

在此一键设置昵称,即可参与社区互动!

*长度不超过10个汉字或20个英文字符,设置后3个月内不可修改。

*长度不超过10个汉字或20个英文字符,设置后3个月内不可修改。