多模态实践--扩散模型的逆向扩散代码实践
【摘要】 反向扩散过程从一张纯高斯噪声图片出发,逐步去除噪声,最终生成一张图片。但每一步需要去除的噪声无法直接获得,因此训练一个神经网络(U-Net)来预测这一噪声。
1.概述
反向扩散过程从一张纯高斯噪声图片出发,逐步去除噪声,最终生成一张图片。但每一步需要去除的噪声无法直接获得,因此训练一个神经网络(U-Net)来预测这一噪声。
2. U-Net 神经网络
2.1 U-Net 神经网络的结构
U-Net 的神经网络结构由卷积块、池化层(下采样)、上采样层,以及把编码器特征拼接到解码器的跳跃连接(skip connection)构成。下面的代码对这一结构做了简化(只有两级下采样/上采样),但思想与 U-Net 网络一致。

# Sinusoidal position encoding of a single time step
def _pos_encoding(time_idx, output_dim, device='cpu'):
    """Return the sinusoidal embedding of one time step.

    Slot ``2i`` holds ``sin(t / 10000**(2i / D))`` and slot ``2i + 1`` holds
    ``cos(t / 10000**(2i / D))``. When ``D`` is odd, the last frequency only
    gets a sine slot.

    Args:
        time_idx: the time step ``t`` (e.g. a Python int or a 0-dim tensor).
        output_dim: embedding length ``D``.
        device: device on which the result is allocated.

    Returns:
        A 1-D tensor of length ``output_dim``.
    """
    embedding = torch.zeros(output_dim, device=device)
    even_slots = torch.arange(0, output_dim, step=2, device=device)
    freqs = torch.pow(10000, even_slots / output_dim)
    embedding[0::2] = torch.sin(time_idx / freqs)
    # There are output_dim // 2 odd slots: one fewer than even slots when D is odd.
    embedding[1::2] = torch.cos(time_idx / freqs[: output_dim // 2])
    return embedding
# Batched sinusoidal position encoding
def pos_encoding(time_steps, output_dim, device='cpu'):
    """Return the sinusoidal embeddings of a batch of time steps.

    Row ``i`` equals the single-step encoding of ``time_steps[i]``: slot
    ``2k`` holds ``sin(t / 10000**(2k / D))`` and slot ``2k + 1`` holds
    ``cos(t / 10000**(2k / D))`` (the last frequency has only a sine slot when
    ``D`` is odd). All rows are computed in one broadcast pass instead of a
    Python loop that rebuilt the frequency table for every element.

    Args:
        time_steps: sequence or 1-D tensor of time steps (a 0-dim tensor is
            treated as a batch of one).
        output_dim: embedding length ``D``.
        device: device on which the result is allocated.

    Returns:
        Tensor of shape ``(len(time_steps), output_dim)``.
    """
    steps = torch.as_tensor(time_steps, device=device).reshape(-1, 1)
    v = torch.zeros(steps.shape[0], output_dim, device=device)
    # Do the division in the embedding's float dtype, as the per-element version did.
    steps = steps.to(v.dtype)
    _2i = torch.arange(0, output_dim, step=2, device=device)
    div_term = torch.pow(10000, _2i / output_dim)  # shape (ceil(D/2),)
    v[:, 0::2] = torch.sin(steps / div_term)
    # output_dim // 2 cosine slots: drops the last frequency when D is odd.
    v[:, 1::2] = torch.cos(steps / div_term[: output_dim // 2])
    return v
# Convolution block
class ConvBlock(nn.Module):
    """Two (3x3 conv -> BatchNorm -> ReLU) stages conditioned on a time embedding.

    The time embedding is mapped by a small MLP to one value per input channel.
    That value is added to the input (broadcast over height and width) before
    the convolutions.

    Args:
        in_ch: channels of the input feature map.
        out_ch: channels produced by the block.
        time_embed_dim: length of the time-step embedding passed to ``forward``.
    """

    def __init__(self, in_ch, out_ch, time_embed_dim):
        super().__init__()
        stages = [
            nn.Conv2d(in_ch, out_ch, 3, padding=1),
            nn.BatchNorm2d(out_ch),
            nn.ReLU(),
            nn.Conv2d(out_ch, out_ch, 3, padding=1),
            nn.BatchNorm2d(out_ch),
            nn.ReLU(),
        ]
        self.convs = nn.Sequential(*stages)
        # Project the time embedding to in_ch values so it can be added per channel.
        self.mlp = nn.Sequential(
            nn.Linear(time_embed_dim, in_ch),
            nn.ReLU(),
            nn.Linear(in_ch, in_ch),
        )

    def forward(self, x, v):
        """Apply the block.

        Args:
            x: feature map of shape (N, C, H, W) with C == in_ch.
            v: time embeddings, one row of length time_embed_dim per sample.

        Returns:
            Feature map of shape (N, out_ch, H, W).
        """
        batch, channels, _, _ = x.shape
        bias = self.mlp(v).view(batch, channels, 1, 1)
        return self.convs(x + bias)
# U-Net architecture
class UNet(nn.Module):
    """Small two-level U-Net that predicts the noise contained in an image.

    Channel flow (with C = in_ch):
        down1: C -> 64, down2: 64 -> 128, bot1: 128 -> 256,
        up2: 256 + 128 -> 128, up1: 128 + 64 -> 64, out (1x1 conv): 64 -> C.
    The decoder concatenates the matching encoder output (skip connection)
    along the channel axis before each up block.

    H and W must be divisible by 4. The two 2x max-pools are undone by two 2x
    upsamples, and each skip concatenation needs matching spatial sizes.

    Args:
        in_ch: number of image channels (same for input and output).
        time_embed_dim: length of the sinusoidal time-step embedding fed to
            every ConvBlock.
    """

    def __init__(self, in_ch=3, time_embed_dim=100):
        super().__init__()
        self.time_embed_dim = time_embed_dim
        self.down1 = ConvBlock(in_ch, 64, time_embed_dim)
        self.down2 = ConvBlock(64, 128, time_embed_dim)
        self.bot1 = ConvBlock(128, 256, time_embed_dim)
        self.up2 = ConvBlock(128 + 256, 128, time_embed_dim)
        self.up1 = ConvBlock(128 + 64, 64, time_embed_dim)
        self.out = nn.Conv2d(64, in_ch, 1)
        self.max_pool = nn.MaxPool2d(2)
        # align_corners=False is PyTorch's bilinear default; stating it avoids
        # relying on (and being warned about) the implicit default.
        self.up_sample = nn.Upsample(scale_factor=2, mode='bilinear', align_corners=False)

    def forward(self, x, time_steps):
        """Predict the noise in ``x`` at the given time steps.

        Args:
            x: noisy images of shape (B, in_ch, H, W).
            time_steps: B time steps, one per image.

        Returns:
            Tensor of shape (B, in_ch, H, W).
        """
        v = pos_encoding(time_steps, self.time_embed_dim, x.device)
        # x: (B, in_ch, H, W)
        x1 = self.down1(x, v)          # (B, 64, H, W)
        x = self.max_pool(x1)          # (B, 64, H/2, W/2)
        x2 = self.down2(x, v)          # (B, 128, H/2, W/2)
        x = self.max_pool(x2)          # (B, 128, H/4, W/4)
        x = self.bot1(x, v)            # (B, 256, H/4, W/4)
        x = self.up_sample(x)          # (B, 256, H/2, W/2)
        x = torch.cat([x, x2], dim=1)  # (B, 256 + 128, H/2, W/2)
        x = self.up2(x, v)             # (B, 128, H/2, W/2)
        x = self.up_sample(x)          # (B, 128, H, W)
        x = torch.cat([x, x1], dim=1)  # (B, 128 + 64, H, W)
        x = self.up1(x, v)             # (B, 64, H, W)
        x = self.out(x)                # (B, in_ch, H, W)
        return x
2.2 模型训练

class ImagesDatasets(Dataset):
    """Training set for the noise-prediction U-Net (DDPM forward process).

    Holds two preprocessed images, each repeated ``batch_size * 2`` times.
    Every access returns a freshly noised sample

        x_t = sqrt(alpha_bar_t) * x_0 + sqrt(1 - alpha_bar_t) * eps

    together with the noise ``eps`` and the time step ``t``. ``t`` is drawn
    uniformly from ``[1, num_timesteps]`` on each access, so training covers
    all time steps rather than a set frozen at construction.

    Reads the module-level ``batch_size``, ``num_timesteps`` and ``device``
    that the training script defines before constructing the dataset.

    Args:
        img_size: side length images are resized to.
        beta_start: first value of the linear variance schedule.
        beta_end: last value of the linear variance schedule. The defaults
            are the linear schedule from the DDPM paper.
    """

    def __init__(self, img_size, beta_start=0.0001, beta_end=0.02):
        self.preprocess = transforms.Compose([
            transforms.Resize((img_size, img_size)),  # Resize the input image
            transforms.ToTensor(),  # Convert to torch tensor (scales data into [0,1])
            transforms.Lambda(lambda t: (t * 2) - 1),  # Scale data between [-1, 1]
        ])
        self.device = device
        self.num_timesteps = num_timesteps
        image_trump = self._load_image("trump.jpeg")
        image_biden = self._load_image("biden.jpeg")
        # Repeat each image batch_size * 2 times (256 each with batch_size=128).
        copies = batch_size * 2
        self.images = torch.stack([image_trump] * copies + [image_biden] * copies).to(device=self.device)
        # Variance schedule [beta_1, beta_2, ..., beta_T]
        self.betas = torch.linspace(beta_start, beta_end, self.num_timesteps).to(device=self.device)
        self.alphas = 1 - self.betas
        self.alpha_bars = torch.cumprod(self.alphas, dim=0)

    def _load_image(self, path):
        """Open ``path``, force 3-channel RGB, preprocess it, and close the file."""
        with Image.open(path) as img:
            return self.preprocess(img.convert("RGB"))

    def __len__(self):
        return len(self.images)

    def __getitem__(self, item):
        """Return ``(x_t, noise, t)`` for image ``item`` at a random time step."""
        x_0 = self.images[item]
        t = torch.randint(1, self.num_timesteps + 1, (), device=self.device)
        alpha_bar = self.alpha_bars[t - 1]  # alpha_bars[0] is for t=1
        # Noise only this sample (the original re-noised the whole dataset per item).
        noise = torch.randn_like(x_0)
        x_t = torch.sqrt(alpha_bar) * x_0 + torch.sqrt(1 - alpha_bar) * noise
        return x_t, noise, t
# Basic configuration
ROOT_DIR = Path(__file__).parent.parent
device = 'cuda' if torch.cuda.is_available() else 'cpu'
log_dir = ROOT_DIR / 'logs' / 'U-Net'
# Hyperparameters
img_size = 32
batch_size = 128
num_timesteps = 1000
# Endpoints of the linear beta (variance) schedule, as in the DDPM paper.
beta_start = 0.0001
beta_end = 0.02
epochs = 1000
lr = 5e-4
images_datasets = ImagesDatasets(img_size)
dataloader = DataLoader(images_datasets, batch_size=batch_size, shuffle=True)
model = UNet().to(device)
optimizer = Adam(model.parameters(), lr=lr)
loss_f = nn.MSELoss()
with SummaryWriter(log_dir=str(log_dir / time.strftime('%Y-%m-%d_%H-%M-%S'))) as writer:
    for epoch in range(epochs):
        model.train()
        loss_sum = 0.0
        cnt = 0
        for x_t, noise, t in tqdm(dataloader, "开始训练: "):
            x_t, noise, t = x_t.to(device=device), noise.to(device=device), t.to(device=device)
            # The U-Net is trained to predict the noise that was added to x_0.
            noise_p = model(x_t, t)
            optimizer.zero_grad()
            loss = loss_f(noise_p, noise)
            loss.backward()
            optimizer.step()
            loss_sum += loss.item()
            cnt += 1
        # Mean batch loss of this epoch (0.0 if the loader yielded nothing).
        loss_avg = loss_sum / cnt if cnt else 0.0
        writer.add_scalar('loss', loss_avg, epoch + 1)
        print(f'Epoch {epoch + 1} | Loss: {loss_avg}')

3. Denoising Diffusion Probabilistic Models,去噪扩散概率模型
# Convert a model-space tensor back to an image
def reverse_to_img(x):
    """Map a tensor with values in [-1, 1] back to a PIL image.

    The values are rescaled to [0, 255], clamped, truncated to uint8, moved to
    the CPU and converted with ``transforms.ToPILImage``.
    """
    pixels = ((x + 1) / 2 * 255).clamp(0, 255).to(torch.uint8).cpu()
    return transforms.ToPILImage()(pixels)
# Show images on a grid (default: 2 rows x 10 columns = 20 images)
def show_images(images, rows=2, cols=10):
    """Display ``images`` on a ``rows`` x ``cols`` matplotlib grid.

    Images beyond ``rows * cols`` are ignored. If fewer are given, the
    remaining cells stay empty instead of raising IndexError.

    Args:
        images: indexable sequence of images accepted by ``plt.imshow``.
        rows: number of grid rows.
        cols: number of grid columns.
    """
    fig = plt.figure(figsize=(cols, rows))
    for i, image in enumerate(images[: rows * cols]):
        fig.add_subplot(rows, cols, i + 1)
        plt.imshow(image, cmap='gray')
        plt.axis('off')
    plt.show()
# Reverse diffusion (DDPM sampling) with the trained model
alphas = images_datasets.alphas
alpha_bars = images_datasets.alpha_bars
preprocess = images_datasets.preprocess
# Load the reference image onto `device` so that combining it with the schedule
# tensors and the noise below does not mix CPU and CUDA tensors.
with Image.open("trump.jpeg") as img:
    image_trump = preprocess(img.convert("RGB")).to(device=device)
images_1 = []
images_2 = []
# Sample a pure white-noise image x_T with the same shape as the training data.
image_noise_pure = torch.randn_like(image_trump)
# The reference image pushed through the forward process to t = T:
# x_T = sqrt(alpha_bar_T) * x_0 + sqrt(1 - alpha_bar_T) * eps
alpha_bar_last = alpha_bars[num_timesteps - 1]
image_noise_add = torch.sqrt(alpha_bar_last) * image_trump + torch.sqrt(1 - alpha_bar_last) * image_noise_pure
x = torch.stack([image_noise_pure, image_noise_add])
batch_size = x.size(0)
model.eval()
with torch.no_grad():
    # for t = T, T-1, ..., 1
    for i in tqdm(range(num_timesteps, 0, -1)):
        t = torch.full((batch_size,), i, dtype=torch.long, device=device)
        # One denoising step, x_t --> x_{t-1}
        t_idx = t - 1  # alphas[0] is for t=1
        alpha = alphas[t_idx].view(-1, 1, 1, 1)
        alpha_bar = alpha_bars[t_idx].view(-1, 1, 1, 1)
        # alpha_bar_0 is 1 by definition; plain alpha_bars[t_idx - 1] would
        # wrap around to alpha_bars[-1] at t=1.
        alpha_bar_prev = torch.where(
            t_idx > 0,
            alpha_bars[(t_idx - 1).clamp(min=0)],
            torch.ones_like(t_idx, dtype=alpha_bars.dtype),
        ).view(-1, 1, 1, 1)
        eps = model(x, t)
        noise = torch.randn_like(x)
        noise[t == 1] = 0  # no noise at t=1
        mu = (x - ((1 - alpha) / torch.sqrt(1 - alpha_bar)) * eps) / torch.sqrt(alpha)
        std = torch.sqrt((1 - alpha) * (1 - alpha_bar_prev) / (1 - alpha_bar))
        # x_{t-1}
        x = mu + noise * std
        # Keep every 50th intermediate result (x_950, ..., x_0 when num_timesteps=1000).
        if (i - 1) % 50 == 0:
            images_1.append(reverse_to_img(x[0]))
            images_2.append(reverse_to_img(x[1]))
show_images(images_1)
show_images(images_2)


【版权声明】本文为华为云社区用户原创内容,未经允许不得转载,如需转载请自行联系原作者进行授权。如果您发现本社区中有涉嫌抄袭的内容,欢迎发送邮件进行举报,并提供相关证据,一经查实,本社区将立刻删除涉嫌侵权内容,举报邮箱:
cloudbbs@huaweicloud.com
- 点赞
- 收藏
- 关注作者
评论(0)