多模态原理--条件扩散模型
【摘要】 通过去噪扩散概率模型(Denoising Diffusion Probabilistic Models,DDPM)可以生成图像,但是生成的图像具有随机性,无法预测。条件扩散模型在去噪扩散模型的基础上增加了条件输入,可以在给定条件下生成指定的图像。
1.概述
通过去噪扩散概率模型(Denoising Diffusion Probabilistic Models,DDPM)可以生成图像,但是生成的图像具有随机性,无法预测。条件扩散模型在去噪扩散模型的基础上增加了条件输入,可以在给定条件下生成指定的图像。

2. 在 U-Net 网络中添加条件
DDPM 中 U-Net 神经网络的输入是添加了噪声的图片和时间步。条件扩散模型额外输入一个条件(本文中为 MNIST 图片的数字类别标签),用来影响预测的噪声,从而控制生成图像的内容。

2. U-Net 神经网络
2.1 U-Net 神经网络的结构
import time
from pathlib import Path
import matplotlib.pyplot as plt
import torch
from PIL import Image
from torch import nn
from torch.optim import Adam
from torch.utils.data import Dataset, DataLoader
from torch.utils.tensorboard import SummaryWriter
from torchvision import transforms
from torchvision.datasets.mnist import MNIST
from tqdm import tqdm
# 按照时间步进行位置编码
def _pos_encoding(time_idx, output_dim, device='cpu'):
t, D = time_idx, output_dim
v = torch.zeros(D, device=device)
_2i = torch.arange(0, D, step=2, device=device)
div_term = torch.pow(10000, _2i / D)
v[0::2] = torch.sin(t / div_term)
if D % 2 == 1:
div_term = div_term[:-1]
v[1::2] = torch.cos(t / div_term)
return v
# Batched sinusoidal position encoding
def pos_encoding(time_steps, output_dim, device='cpu'):
    """Encode a batch of time steps with sinusoidal position encoding.

    Row ``b`` of the result is ``[sin(t_b / d_0), cos(t_b / d_0), sin(t_b / d_1), ...]``
    with ``d_i = 10000 ** (2i / output_dim)``. This is the same vector that
    ``_pos_encoding`` produces for ``t_b``. The whole batch is computed in one
    vectorized expression instead of one Python-level call per sample.

    Args:
        time_steps: time steps of shape (B,) or (B, 1); a tensor or a sequence of numbers.
        output_dim: length of each encoding vector.
        device: device of the returned tensor; ``time_steps`` is moved there.

    Returns:
        Float tensor of shape (B, output_dim).
    """
    t = torch.as_tensor(time_steps, device=device).reshape(-1, 1)  # (B, 1)
    exponents = torch.arange(0, output_dim, step=2, device=device) / output_dim
    div_term = torch.pow(10000, exponents)  # (ceil(output_dim / 2),)
    v = torch.zeros(t.shape[0], output_dim, device=device)
    v[:, 0::2] = torch.sin(t / div_term)
    # An odd output_dim has one fewer cosine column than sine columns.
    v[:, 1::2] = torch.cos(t / div_term[: output_dim // 2])
    return v

# Convolution block conditioned on an embedding vector
class ConvBlock(nn.Module):
    """Two (3x3 conv -> BatchNorm -> ReLU) layers, conditioned on an embedding vector.

    The embedding ``v`` is passed through a small MLP to ``in_ch`` values. These
    are broadcast-added to every spatial position of ``x`` before the convolutions.
    """

    def __init__(self, in_ch, out_ch, time_embed_dim):
        super().__init__()
        layers = []
        for conv_in in (in_ch, out_ch):
            layers.extend([
                nn.Conv2d(conv_in, out_ch, 3, padding=1),
                nn.BatchNorm2d(out_ch),
                nn.ReLU(),
            ])
        self.convs = nn.Sequential(*layers)
        # Maps the embedding to one additive offset per input channel.
        self.mlp = nn.Sequential(
            nn.Linear(time_embed_dim, in_ch),
            nn.ReLU(),
            nn.Linear(in_ch, in_ch),
        )

    def forward(self, x, v):
        """Apply the block to ``x`` of shape (N, C, H, W), with ``v`` of shape (N, time_embed_dim)."""
        batch, channels, _, _ = x.shape
        offset = self.mlp(v).view(batch, channels, 1, 1)
        return self.convs(x + offset)
# U-Net backbone conditioned on the time step and (optionally) a class label
class UNetCondition(nn.Module):
    """Two-level U-Net that predicts the noise in ``x`` for a conditional DDPM.

    The conditioning vector is the sinusoidal encoding of ``time_steps``. When
    ``labels`` are given, a learned label embedding is added to it. Every
    ``ConvBlock`` receives this vector.

    Args:
        in_ch: number of image channels.
        time_embed_dim: size of the time/label embedding vector.
        num_labels: number of classes; ``None`` builds a model without a
            label embedding, which can then only be called without ``labels``.

    H and W must be divisible by 4, because there are two 2x pooling levels whose
    outputs are concatenated with the upsampled features.
    """

    def __init__(self, in_ch=3, time_embed_dim=100, num_labels=None):
        super().__init__()
        self.time_embed_dim = time_embed_dim
        self.num_labels = num_labels
        self.down1 = ConvBlock(in_ch, 64, time_embed_dim)
        self.down2 = ConvBlock(64, 128, time_embed_dim)
        self.bot1 = ConvBlock(128, 256, time_embed_dim)
        self.up2 = ConvBlock(128 + 256, 128, time_embed_dim)
        self.up1 = ConvBlock(128 + 64, 64, time_embed_dim)
        self.out = nn.Conv2d(64, in_ch, 1)
        self.max_pool = nn.MaxPool2d(2)
        self.up_sample = nn.Upsample(scale_factor=2, mode='bilinear')
        if self.num_labels is not None:
            self.label_emb = nn.Embedding(num_labels, time_embed_dim)

    def forward(self, x, time_steps, labels=None):
        """Predict the noise contained in the batch ``x``.

        Args:
            x: noisy images, shape (B, in_ch, H, W).
            time_steps: per-sample time step, shape (B,) or (B, 1).
            labels: optional class indices, shape (B,); requires ``num_labels``.

        Returns:
            Predicted noise, shape (B, in_ch, H, W).

        Raises:
            ValueError: if ``labels`` is given but the model was built with
                ``num_labels=None``.
        """
        v = pos_encoding(time_steps, self.time_embed_dim, x.device)  # (B, time_embed_dim)
        if labels is not None:
            if self.num_labels is None:
                raise ValueError('labels were given but the model was created with num_labels=None')
            v = v + self.label_emb(labels)
        x1 = self.down1(x, v)          # (B, 64, H, W)
        x = self.max_pool(x1)          # (B, 64, H/2, W/2)
        x2 = self.down2(x, v)          # (B, 128, H/2, W/2)
        x = self.max_pool(x2)          # (B, 128, H/4, W/4)
        x = self.bot1(x, v)            # (B, 256, H/4, W/4)
        x = self.up_sample(x)          # (B, 256, H/2, W/2)
        x = torch.cat([x, x2], dim=1)  # (B, 256 + 128, H/2, W/2)  skip connection
        x = self.up2(x, v)             # (B, 128, H/2, W/2)
        x = self.up_sample(x)          # (B, 128, H, W)
        x = torch.cat([x, x1], dim=1)  # (B, 128 + 64, H, W)  skip connection
        x = self.up1(x, v)             # (B, 64, H, W)
        x = self.out(x)                # (B, in_ch, H, W)
        return x
2.2 模型训练
class ImagesDatasets(Dataset):
    """MNIST training set that returns forward-diffused samples for DDPM training.

    ``__getitem__`` draws a random time step ``T`` in ``[1, num_timesteps]`` and
    returns ``(x_t, target, noise, T)``, where
    ``x_t = sqrt(alpha_bar_T) * x_0 + sqrt(1 - alpha_bar_T) * noise``. Here
    ``x_0`` is the image scaled to [-1, 1], ``target`` is its MNIST label and
    ``T`` has shape (1,). The linear variance schedule (``betas``, ``alphas``,
    ``alpha_bars``) is kept as attributes so sampling code can reuse it.

    Args:
        img_size: images are resized to ``img_size x img_size``.
        num_timesteps: number of diffusion steps; ``None`` (default) uses the
            module-level ``num_timesteps`` setting.
        device: device for the schedule and the returned tensors; ``None``
            (default) uses the module-level ``device`` setting.
    """

    def __init__(self, img_size, num_timesteps=None, device=None):
        # The module-level settings are looked up here, at construction time,
        # rather than used as default values, because the script defines them
        # only after this class. This keeps ``ImagesDatasets(img_size)`` working.
        settings = globals()
        self.num_timesteps = settings['num_timesteps'] if num_timesteps is None else num_timesteps
        self.device = settings['device'] if device is None else device
        self.preprocess = transforms.Compose([
            transforms.Resize((img_size, img_size)),  # resize the input image
            transforms.ToTensor(),  # convert to a float tensor in [0, 1]
            transforms.Lambda(lambda t: (t * 2) - 1),  # rescale to [-1, 1]
        ])
        self.dataset = MNIST(root="./datasets", train=True, download=True, transform=self.preprocess)
        # Variance schedule [beta_1, beta_2, ..., beta_T], linear in t.
        self.betas = torch.linspace(0.0001, 0.02, self.num_timesteps).to(device=self.device)
        self.alphas = 1 - self.betas
        self.alpha_bars = torch.cumprod(self.alphas, dim=0)

    def __len__(self):
        return len(self.dataset)

    def __getitem__(self, item):
        # NOTE(review): tensors are created on self.device inside the dataset.
        # Creating CUDA tensors in DataLoader worker processes is generally not
        # supported with the default fork start method, so keep num_workers=0
        # when self.device is 'cuda'.
        img, target = self.dataset[item]
        img = img.to(device=self.device)
        T = torch.randint(1, self.num_timesteps + 1, (1,)).to(device=self.device)
        t_idx = T - 1  # alpha_bars[0] belongs to T = 1
        alpha_bar = self.alpha_bars[t_idx]
        noise = torch.randn_like(img)  # already on img's device
        x_t = torch.sqrt(alpha_bar) * img + torch.sqrt(1 - alpha_bar) * noise
        return x_t, target, noise, T
# Basic configuration
ROOT_DIR = Path(__file__).parent.parent
device = 'cuda' if torch.cuda.is_available() else 'cpu'
log_dir = ROOT_DIR / 'logs' / 'DDPM'
# Hyperparameters
img_size = 32
batch_size = 128
num_timesteps = 1000
epochs = 50
lr = 5e-4
num_labels = 10
images_datasets = ImagesDatasets(img_size)
data_loader = DataLoader(images_datasets, batch_size=batch_size, shuffle=True)
model = UNetCondition(in_ch=1, num_labels=num_labels).to(device)
optimizer = Adam(model.parameters(), lr=lr)
loss_f = nn.MSELoss()
with SummaryWriter(log_dir=str(log_dir / time.strftime('%Y-%m-%d_%H-%M-%S'))) as writer:
    for epoch in range(epochs):
        # Enter training mode explicitly. The sampling code calls model.eval(),
        # so re-running this loop afterwards would otherwise train with frozen
        # BatchNorm statistics.
        model.train()
        loss_sum = 0.0
        cnt = 0
        for x_t, labels, noise, t in tqdm(data_loader, "开始训练: "):
            x_t, labels, noise, t = (tensor.to(device=device) for tensor in (x_t, labels, noise, t))
            optimizer.zero_grad()
            noise_p = model(x_t, t, labels)  # predicted noise for the noised batch
            loss = loss_f(noise_p, noise)
            loss.backward()
            optimizer.step()
            loss_sum += loss.item()
            cnt += 1
        # Mean batch loss over the epoch (0.0 if the loader produced no batches).
        loss_avg = loss_sum / cnt if cnt else 0.0
        writer.add_scalar('loss', loss_avg, epoch + 1)
        print(f'Epoch {epoch + 1} | Loss: {loss_avg}')

3. 条件扩散去噪模型
# Convert model-space pixel values back into an image
def reverse_to_img(x):
    """Convert a tensor with values in [-1, 1] into a PIL image.

    Values are mapped linearly to [0, 255], clamped, cast to ``uint8`` (the cast
    drops the fractional part), and moved to the CPU before conversion.
    """
    pixels = ((x + 1) / 2 * 255).clamp(0, 255)
    pixels = pixels.to(torch.uint8).cpu()
    return transforms.ToPILImage()(pixels)
# Show images in a grid (default: 2 rows x 10 columns, i.e. 20 images)
def show_images(images, rows=2, cols=10):
    """Display up to ``rows * cols`` images as a grayscale grid in one figure.

    If fewer images than grid cells are given, the remaining cells stay empty
    instead of raising ``IndexError``. Images beyond ``rows * cols`` are not shown.
    """
    fig = plt.figure(figsize=(cols, rows))
    for i, img in zip(range(rows * cols), images):
        ax = fig.add_subplot(rows, cols, i + 1)
        ax.imshow(img, cmap='gray')
        ax.axis('off')
    plt.show()
# Reuse the variance schedule the training dataset was built with.
alphas = images_datasets.alphas
alpha_bars = images_datasets.alpha_bars
# alpha_bar_{t-1} for every t. alpha_bar_0 is defined as 1, so the t = 1 step
# does not wrap around to alpha_bars[-1].
alpha_bars_prev = torch.cat([torch.ones(1, device=alpha_bars.device), alpha_bars[:-1]])
images = {}  # label -> list of PIL snapshots, from noisiest to final image
labels = torch.arange(0, num_labels).to(device=device)
batch_size = len(labels)
# Sample one pure-noise image x_T with the training image shape. It is shared
# by all labels, so the generated images differ only by their class label.
image_noise_pure = torch.randn((1, img_size, img_size)).to(device=device)
x = torch.stack([image_noise_pure] * batch_size)
model.eval()
# for t = T, T-1, ..., 1
for i in tqdm(range(num_timesteps, 0, -1)):
    t = torch.full((batch_size,), i, dtype=torch.long, device=device)
    # One denoising step: x_t --> x_{t-1}
    t_idx = t - 1  # alphas[0] belongs to t = 1
    alpha = alphas[t_idx].view(-1, 1, 1, 1)
    alpha_bar = alpha_bars[t_idx].view(-1, 1, 1, 1)
    alpha_bar_prev = alpha_bars_prev[t_idx].view(-1, 1, 1, 1)
    with torch.no_grad():
        eps = model(x, t, labels)
    noise = torch.randn_like(x)
    noise[t == 1] = 0  # the final step adds no noise
    mu = (x - ((1 - alpha) / torch.sqrt(1 - alpha_bar)) * eps) / torch.sqrt(alpha)
    std = torch.sqrt((1 - alpha) * (1 - alpha_bar_prev) / (1 - alpha_bar))
    x = mu + noise * std  # x_{t-1}
    if (i - 1) % 50 == 0:
        # x[n] selects the batch row by label value; valid because labels == arange(batch_size).
        for n in labels.tolist():
            images.setdefault(n, []).append(reverse_to_img(x[n]))
for n in labels.tolist():
    show_images(images[n])










【版权声明】本文为华为云社区用户原创内容,未经允许不得转载,如需转载请自行联系原作者进行授权。如果您发现本社区中有涉嫌抄袭的内容,欢迎发送邮件进行举报,并提供相关证据,一经查实,本社区将立刻删除涉嫌侵权内容,举报邮箱:
cloudbbs@huaweicloud.com
- 点赞
- 收藏
- 关注作者
评论(0)