多模态实践--扩散模型的正向扩散代码实践
【摘要】 前向扩散过程逐步将高斯噪声添加到输入图像 𝑥0 中,总共会有 𝑇 步。该过程将产生一系列带噪声的图像样本 𝑥1, 𝑥2, …, 𝑥𝑇 。当 𝑇 → ∞ 时,最终结果将变成完全噪声图像
1.概述
前向扩散过程逐步将高斯噪声添加到输入图像 𝑥0 中,总共会有 𝑇 步。该过程将产生一系列带噪声的图像样本 𝑥1, 𝑥2, …, 𝑥𝑇 。当 𝑇 → ∞ 时,最终结果将变成完全噪声图像。

2. 前向扩散迭代公式的代码实践
如果 𝑧 ∼ 𝒩︀(𝜇, 𝜎2) 的话,那么正态分布可以写成如下公式:𝑧 = 𝜇 + 𝜎𝜀 其中 𝜀 ∼ 𝒩︀(0, 1)。

方差计划为 𝛽start = 0.0002, 𝛽end = 0.04,时间步是1000的线性计划。
import torch
import torchvision.transforms as transforms
import matplotlib.pyplot as plt
def reverse_to_img(x):
x = x * 255
x = x.clamp(0, 255)
x = x.to(torch.uint8)
to_pil = transforms.ToPILImage()
return to_pil(x)
# 最大时间步
T = 1000
# 方差计划的起始值
beta_start = 0.0001
# 方差计划的结束值
beta_end = 0.02
# 线性计划
betas = torch.linspace(beta_start, beta_end, T)
image = plt.imread("./flower.png")
preprocess = transforms.ToTensor()
x = preprocess(image)
imgs = []
for t in range(T):
# 每添加一百次噪声,保存噪声图片
if t % 100 == 0:
img = reverse_to_img(x)
imgs.append(img)
beta = betas[t]
eps = torch.randn_like(x) # 生成和x形状相同的符合标准正态分布的噪声
# 按照迭代公式添加噪声到到图片中
x = torch.sqrt(1 - beta) * x + torch.sqrt(beta) * eps
# 使用两行五列的方式显示10张图片
plt.figure(figsize=(15, 6))
for i, img in enumerate(imgs[:10]):
plt.subplot(2, 5, i + 1)
plt.imshow(img)
plt.title(f"Noise: {i * 100}")
plt.axis("off")
plt.show()

3. 前向扩散闭合公式的代码实践
给定原始图片 𝑥0 和时间步 𝑡 直接采样出 𝑥𝑡 的公式:

import torch
import torchvision.transforms as transforms
import matplotlib.pyplot as plt
def reverse_to_img(x):
x = x * 255
x = x.clamp(0, 255)
x = x.to(torch.uint8)
to_pil = transforms.ToPILImage()
return to_pil(x)
# 最大时间步
T = 1000
# 方差计划的起始值
beta_start = 0.0001
# 方差计划的结束值
beta_end = 0.02
# 线性计划
betas = torch.linspace(beta_start, beta_end, T)
alphas = 1.0 - betas
alpha_bars = torch.cumprod(alphas,dim=0)
image = plt.imread("./flower.png")
preprocess = transforms.ToTensor()
x_0= preprocess(image)
imgs = []
for t in range(0,1000,100):
eps = torch.randn_like(x_0) # 生成和x形状相同的符合标准正态分布的噪声
img = reverse_to_img(torch.sqrt(alpha_bars[t]) * x_0 + torch.sqrt(1-alpha_bars[t]) * eps)
# 按照闭合公式添加噪声到到图片中
imgs.append(img)
# 使用两行五列的方式显示10张图片
plt.figure(figsize=(15, 6))
for i, img in enumerate(imgs[:10]):
plt.subplot(2, 5, i + 1)
plt.imshow(img)
plt.title(f"Noise: {i * 100}")
plt.axis("off")
plt.show()

【版权声明】本文为华为云社区用户原创内容,未经允许不得转载,如需转载请自行联系原作者进行授权。如果您发现本社区中有涉嫌抄袭的内容,欢迎发送邮件进行举报,并提供相关证据,一经查实,本社区将立刻删除涉嫌侵权内容,举报邮箱:
cloudbbs@huaweicloud.com
- 点赞
- 收藏
- 关注作者
评论(0)